| from typing import List, Tuple |
| import itertools |
| import math |
|
|
| def _pack_rank_ids(buf: List[int], rank_bitlength: int) -> List[int]: |
| per_b = 8 // rank_bitlength |
| mask = (1 << rank_bitlength) - 1 |
| out_b = [] |
| it = iter(buf) |
|
|
| while True: |
| chunk = list(itertools.islice(it, per_b)) |
| if not chunk: |
| break |
| byte_val = 0 |
| for p, idx in enumerate(chunk): |
| byte_val |= (idx & mask) << (p * rank_bitlength) |
| out_b.append(byte_val) |
| return out_b |
|
|
| def _unpack_rank_ids(payload: List[int], run_len: int, rank_bitlength: int): |
| mask = (1 << rank_bitlength) - 1 |
|
|
| byte_iter = iter(payload) |
| cur_byte = next(byte_iter) |
| filled = 8 |
|
|
| for _ in range(run_len): |
| if filled == 0: |
| cur_byte = next(byte_iter) |
| filled = 8 |
| rank_id = cur_byte & mask |
| cur_byte >>= rank_bitlength |
| filled -= rank_bitlength |
| yield rank_id |
|
|
class SimpleAdaptiveRankCodec:
    """Window codec mixing literal bytes, run-length events and rank runs.

    An encoded stream is a flat list of ints made of three kinds of items:

    * a value below ``raw_byte_offset`` is a literal byte;
    * ``sentinel_rle, run, byte`` stands for ``byte`` repeated ``run`` times;
    * ``sentinel_rank_run, count, packed...`` carries ``count`` rank ids,
      bit-packed ``ranks_per_byte`` to a payload value.
    """

    def __init__(
        self,
        top_k: int = 4,
        tau: float = 0.5,
        min_run: int = 3,
        max_run: int = 255,
        sentinel_rle: int = 256,
        sentinel_rank_run: int = 257,
    ):
        """Configure the codec.

        Args:
            top_k: Positions whose rank is strictly below this value are
                stored as rank ids; all others are stored as literal bytes.
            tau: Minimum ``repeat_probs`` value a position needs in order to
                extend a run.
            min_run: Shortest run that is emitted as an RLE event.
            max_run: Longest run a single RLE event may describe.
            sentinel_rle: Stream value that introduces an RLE event.
            sentinel_rank_run: Stream value that introduces a rank run.

        Raises:
            AssertionError: If the bit width derived from ``top_k`` exceeds 8
                or does not divide 8.
        """
        self.top_k = top_k
        self.tau = tau
        self.min_run = min_run
        self.max_run = max_run
        # Values below this offset are literal bytes in every stream format.
        self.raw_byte_offset = 256

        # Smallest bit width able to represent rank ids 0 .. top_k - 1.
        self.rank_bitlength = max(1, (top_k - 1).bit_length())
        assert self.rank_bitlength <= 8 and 8 % self.rank_bitlength == 0, (
            f"rank_bitlength must be between 1 and 8 and must divide 8, "
            f"got {self.rank_bitlength} (top_k: {top_k})"
        )
        self.ranks_per_byte = 8 // self.rank_bitlength
        self.sentinel_rle = sentinel_rle
        self.sentinel_rank_run = sentinel_rank_run

    def encode_window(
        self,
        tokens: List[int],
        repeat_probs: List[float],
        ranks: List[int],
    ) -> List[int]:
        """Return a list of ints: raw bytes 0-255 and sentinel events ≥256.

        The first token is always emitted as a literal. From then on, at each
        position the encoder tries, in order:

        1. a run of identical tokens of at least ``min_run`` and at most
           ``max_run`` positions, where every position after the first has
           ``repeat_probs >= tau``, emitted as ``sentinel_rle, run, token``;
        2. a rank id, buffered when ``ranks[i] < top_k`` and later flushed as
           ``sentinel_rank_run, count, packed...``;
        3. the literal token.

        Args:
            tokens: Byte values of the window.
            repeat_probs: Per-position repeat probability, compared to ``tau``.
            ranks: Per-position rank of the true token.

        Returns:
            The encoded stream; an empty list for an empty window.

        Raises:
            AssertionError: If the three inputs differ in length.
        """
        assert len(tokens) == len(repeat_probs) == len(ranks)
        if not tokens:
            # Nothing to encode; avoids indexing tokens[0] below.
            return []

        rank_buf: List[int] = []
        out: List[int] = [tokens[0]]
        i, n = 1, len(tokens)

        def flush_rank_buf():
            """Emit the pending rank ids as one rank-run event."""
            if not rank_buf:
                return
            out.append(self.sentinel_rank_run)
            out.append(len(rank_buf))
            out.extend(_pack_rank_ids(rank_buf, self.rank_bitlength))
            rank_buf.clear()

        while i < n:
            tok = tokens[i]

            # Measure the run starting at i. It is capped at max_run so one
            # event never describes more repeats than the configured limit;
            # a longer run simply continues as a fresh event next iteration.
            run = 1
            while (
                run < self.max_run
                and i + run < n
                and tokens[i + run] == tok
                and repeat_probs[i + run] >= self.tau
            ):
                run += 1
            if run >= self.min_run:
                flush_rank_buf()
                out.extend([self.sentinel_rle, run, tok])
                i += run
                continue

            if ranks[i] < self.top_k:
                rank_buf.append(ranks[i])
            else:
                # A literal interrupts the current rank run, so the pending
                # ids must be written first to keep positions in order.
                flush_rank_buf()
                out.append(tok)
            i += 1

        flush_rank_buf()
        return out

    def encoding_to_pseudo_bytes(self, enc: list[int]) -> list[int]:
        """Flatten an encoded stream into a sentinel-free value sequence.

        * Literal bytes are copied unchanged.
        * An RLE event becomes two values: ``2 * raw_byte_offset - run``
          (with ``run`` clamped to ``max_run``) followed by the repeated byte.
        * A rank run becomes its packed payload values, each shifted up by
          ``raw_byte_offset``; the run's length field is not emitted.

        NOTE(review): the RLE marker values and the shifted rank payload
        values fall in overlapping ranges - confirm this is intended.

        Args:
            enc: Stream produced by ``encode_window``.

        Returns:
            The pseudo-byte sequence.

        Raises:
            ValueError: If an unrecognised value at or above
                ``raw_byte_offset`` appears where an event is expected.
        """
        out: list[int] = []
        i = 0
        while i < len(enc):
            tok = enc[i]
            i += 1

            if tok < self.raw_byte_offset:
                out.append(tok)

            elif tok == self.sentinel_rle:
                run = enc[i]
                raw = enc[i + 1]
                i += 2
                run = min(run, self.max_run)
                out.extend([2 * self.raw_byte_offset - run, raw])

            elif tok == self.sentinel_rank_run:
                length = enc[i]
                i += 1
                n_bytes = math.ceil(length / self.ranks_per_byte)
                for _ in range(n_bytes):
                    out.append(enc[i] + self.raw_byte_offset)
                    i += 1
            else:
                raise ValueError(f"unknown token {tok}")
        return out

    def pseudo_bytes_to_encoding(self, pb: list[int], original_encoding: list[int]) -> list[int]:
        """Inverse of ``encoding_to_pseudo_bytes``; not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Not implemented")

    def decode_window(
        self,
        stream: List[int],
        original_len: int,
        topk_symbols: List[List[int]],
    ) -> List[int]:
        """Rebuild the token window from an encoded stream.

        `topk_symbols[pos][idx]` must give the byte value (0-255) that
        corresponds to rank `idx` at position `pos`, e.g. recomputed from
        the helper LM during decoding.

        Args:
            stream: Encoded stream as produced by ``encode_window``.
            original_len: Number of tokens to reconstruct.
            topk_symbols: Per-position lookup from rank id to byte value.

        Returns:
            The decoded tokens, truncated to ``original_len``.

        Raises:
            ValueError: If an unrecognised value at or above
                ``raw_byte_offset`` appears where an event is expected.
        """
        out: List[int] = []
        i = 0  # read cursor into `stream`
        pos = 0  # number of positions reconstructed so far

        while pos < original_len:
            tok = stream[i]
            i += 1

            if tok < self.raw_byte_offset:
                out.append(tok)
                pos += 1

            elif tok == self.sentinel_rle:
                run_len = stream[i]
                raw = stream[i + 1]
                i += 2
                out.extend([raw] * run_len)
                pos += run_len

            elif tok == self.sentinel_rank_run:
                run_len = stream[i]
                i += 1
                bytes_needed = math.ceil(run_len / self.ranks_per_byte)
                payload = stream[i: i + bytes_needed]
                i += bytes_needed
                for rank_id in _unpack_rank_ids(payload, run_len, self.rank_bitlength):
                    out.append(topk_symbols[pos][rank_id])
                    pos += 1
            else:
                raise ValueError(f"Unknown sentinel {tok}")

        # An event may overshoot the window; trim to the requested length.
        return out[:original_len]
|
|
if __name__ == "__main__":
    # Round-trip smoke test on a random window.
    import random

    import torch

    # Every sample below is drawn from torch's generator, so that is the one
    # that has to be seeded for the run to be reproducible.
    random.seed(0)
    torch.manual_seed(0)

    T, K = 384, 13
    tokens = torch.randint(0, 32, (T,)).tolist()
    repeat_probs = torch.rand(T).tolist()
    ranks = torch.randint(0, K + 5, (T,)).tolist()
    # Collapse every out-of-range rank onto K, i.e. "not in the top-k".
    ranks = [min(r, K) for r in ranks]

    # Stand-in for the helper LM: every rank at position t maps to the true
    # token, so decoding must reproduce `tokens` exactly.
    topk = [[tokens[t]] * K for t in range(T)]

    codec = SimpleAdaptiveRankCodec(top_k=K, tau=0.00)

    enc = codec.encode_window(tokens, repeat_probs, ranks)
    dec = codec.decode_window(enc, T, topk)

    print(f"raw={T} encoded={len(enc)} ratio={len(enc)/T:.2f}")
    assert dec == tokens
    print("✓ window-enc-dec round-trip passes")
|
|