# m1_compression/hybrid_arithmetic_coder.py
# Provenance (Hugging Face listing metadata): repo "Byte-lingua-code",
# uploaded by user "2ira", commit "offline_compression_graph_code",
# revision 72c0672 (verified).
import torch
from m1_compression import utils
import math
import numpy as np
from typing import List, Tuple, Callable, Any, Dict, Optional
import logging
from m1_compression.batched_arithmetic_coder import (
BatchedArithmeticEncoder,
)
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
class CPUArithmeticEncoder(BatchedArithmeticEncoder):
    """Batched arithmetic encoder supporting incremental encoding with an
    optional per-sequence bit budget (early stop).

    All working tensors are allocated on ``gathered_cdfs.device``.
    NOTE(review): despite the "CPU" name, nothing here forces CPU tensors —
    confirm intended device usage with callers.
    """

    def __init__(self, base: int, precision: int):
        # All coder state (self._base, self._precision, self._base_to_pm1, ...)
        # is set up by the BatchedArithmeticEncoder base class.
        super().__init__(base=base, precision=precision)

    def batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]]:
        # One-shot batched encoding is intentionally unsupported here;
        # use incremental_batched_encode instead.
        raise NotImplementedError("CPUArithmeticEncoder does not support batched_encode")

    def incremental_batched_encode(
        self,
        gathered_cdfs: torch.Tensor,  # [B, T, 2]
        vocab_size: int,
        lengths: Optional[torch.Tensor] = None,
        bit_threshold: Optional[int] = None,
        force_padding_to_threshold: bool = False,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode a batch of sequences, optionally freezing a
        sequence once its finalized output would exceed ``bit_threshold`` bits.

        Each step t narrows the coding interval of every active sequence using
        the CDF bounds in ``gathered_cdfs[:, t, :]`` (lower bound at index 0,
        upper bound at index 1). The symbol identity is already baked into
        those bounds, so no separate symbols tensor is taken.

        Args:
            gathered_cdfs: [B, T, 2] per-step (low, high) CDF bounds of the
                symbol to encode, as fractions of the current interval width.
            vocab_size: alphabet size; used only to bound the digit buffer.
            lengths: [B] number of valid steps per sequence (defaults to T;
                clamped to [0, T]).
            bit_threshold: if set, a sequence that exceeds this many bits at
                step t is frozen with the finalized output from step t-1.
            force_padding_to_threshold: pad every bitstream up to
                ``bit_threshold`` when converting to bytes.
                NOTE(review): this path passes ``bit_threshold`` to utils, so
                it presumably requires ``bit_threshold`` to be set — confirm.
            return_num_padded_bits: also return per-sequence padding counts.

        Returns:
            final_compressed_bytes: List[bytes], one stream per sequence.
            stopped_at_step: List[int]; t for a sequence frozen by the bit
                threshold at step t, t + 1 (i.e. its length) for a sequence
                that ran to completion, and -1 only if the sequence was never
                finalized (the loop exited before any step processed it).
            final_num_padded_bits: List[int] of padding bits per sequence,
                only when ``return_num_padded_bits`` is True.
        """
        B, T, _ = gathered_cdfs.shape
        device = gathered_cdfs.device
        if lengths is None:
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)
        lengths = torch.clamp(lengths, min=0, max=T)
        # Initialize arithmetic coding state: the interval [low, high] starts
        # spanning the full precision range [0, base**precision - 1].
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)
        # Initialize the output digit buffer. Capacity per sequence:
        # `precision` termination digits plus (presumably) at most
        # 2 * digits_sym emitted digits per encoded step — TODO confirm bound.
        digits_sym = math.ceil(math.log(vocab_size, self._base))
        max_digits = self._precision + 2 * T * digits_sym
        # One flat buffer; sequence idx owns [idx*max_digits, (idx+1)*max_digits).
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
        # base_offsets stays at each sequence's slot start; buf_offsets is the
        # moving write cursor (advanced by flush_matching_digits).
        base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits
        # Pre-allocate temporary buffers (avoid cloning at each step)
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)
        # Track final results for each sequence - save buffer states, not bytes
        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B  # -1 only if a sequence is never finalized (see docstring)
        # Track which sequences are still active
        active_sequences = torch.ones(B, dtype=torch.bool, device=device)
        # Previous step's finalized buffer state; this is what a sequence
        # falls back to when the current step pushes it over bit_threshold.
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)
        for t in range(T):
            valid = (t < lengths) & active_sequences
            if not valid.any():
                break  # All sequences completed or stopped
            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1
            # old_low is needed by flush_matching_digits to detect a carry.
            old_low = low.clone()
            # Narrow each active interval to the symbol's CDF sub-range.
            low[valid] = low_valid + (gathered_cdfs[valid, t, 0] * width_valid).to(torch.int64)
            high[valid] = low_valid + (gathered_cdfs[valid, t, 1] * width_valid).to(torch.int64) - 1
            # Emit digits that low and high now agree on, then renormalize
            # (both helpers are defined in BatchedArithmeticEncoder).
            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low,
                encoding=True,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )
            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high,
                encoding=True,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )
            # Finalization is only simulated when it can matter: a bit
            # threshold is being enforced, or some sequence ends this step.
            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()
            if need_check_threshold or some_seq_finished:
                # Simulate termination at this step on scratch copies so the
                # live encoder state (bits_buffer/buf_offsets/...) is untouched.
                temp_bits_buffer.copy_(bits_buffer, True)
                temp_buf_offsets.copy_(buf_offsets, True)
                temp_num_carry_digits.copy_(num_carry_digits, True)
                # Append one terminating digit (top digit of low) per sequence.
                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1
                # Flush any pending carry digits (value base-1, repeated
                # num_carry_digits times) for sequences that have them.
                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max()
                    # grid/mask_rep scatter a variable-length run of carry
                    # digits per selected sequence in one vectorized write.
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)
                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0
                # Check bit threshold and identify newly stopped sequences
                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences
                    if exceeds_threshold.any():
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():  # Only move indices to CPU
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t
                            # Save the result from the PREVIOUS step — the last
                            # state that still fit within the threshold.
                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            # prev_finalized_ends holds absolute offsets into
                            # the flat buffer, so it is used directly as the end.
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])
                # Sequences reaching their last valid step this iteration are
                # finalized with the CURRENT simulated termination state.
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1
                        # Save current step result for sequences that completed normally
                        final_buffer_ends[idx] = temp_buf_offsets[idx]
                        # Copy the finalized bits to main buffer for this sequence
                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])
                # Remember this step's finalized state as the fallback for the
                # next iteration's threshold check.
                if need_check_threshold:
                    prev_finalized_buffer.copy_(temp_bits_buffer)
                    prev_finalized_ends.copy_(temp_buf_offsets)
        # Convert saved buffer states to compressed bytes once, at the end.
        # NOTE(review): digits are joined as characters and fed to
        # utils.bits_to_bytes* — presumably base == 2 here; verify for other bases.
        final_compressed_bytes = []
        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_buffer_ends[idx]
            bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
            bitstr = "".join(map(str, bits_list))
            if force_padding_to_threshold:
                comp_bytes, num_padded = utils.bits_to_bytes_padding_to_threshold(bitstr, bit_threshold)
            else:
                comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded
        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step