| import torch |
| from m1_compression import utils |
| import math |
| import numpy as np |
| from typing import List, Tuple, Callable, Any, Dict, Optional |
| import logging |
| from m1_compression.batched_arithmetic_coder import ( |
| BatchedArithmeticEncoder, |
| ) |
|
|
# Configure logging at import time with DEBUG verbosity for the whole process.
# NOTE(review): calling basicConfig() on import and grabbing the *root* logger
# is unusual for a library module — it overrides the host application's logging
# setup. Consider logging.getLogger(__name__) and leaving configuration to the
# application entry point; confirm nothing depends on the current behavior.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
|
|
class CPUArithmeticEncoder(BatchedArithmeticEncoder):
    """Arithmetic encoder that only supports the incremental, early-stopping
    encode path; the plain one-shot ``batched_encode`` is deliberately disabled.

    All coder state (``base``, ``precision`` and derived constants such as
    ``_base_to_pm1``) lives in :class:`BatchedArithmeticEncoder`.

    NOTE(review): despite the "CPU" name, the tensor code below runs on
    whatever device the input ``gathered_cdfs`` tensor lives on — confirm
    the intended device contract with callers.
    """

    def __init__(self, base: int, precision: int):
        # Pure delegation: no encoder-local state beyond the base class.
        super().__init__(base=base, precision=precision)


    def batched_encode(
        self,
        gathered_cdfs: torch.Tensor,
        symbols: torch.Tensor,
        lengths: Optional[torch.Tensor] = None,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]]:
        """Unsupported on this encoder; use ``incremental_batched_encode``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("CPUArithmeticEncoder does not support batched_encode")


    def incremental_batched_encode(
        self,
        gathered_cdfs: torch.Tensor,
        vocab_size: int,
        lengths: Optional[torch.Tensor] = None,
        bit_threshold: Optional[int] = None,
        force_padding_to_threshold: bool = False,
        return_num_padded_bits: bool = False
    ) -> Tuple[List[bytes], List[int]] | Tuple[List[bytes], List[int], List[int]]:
        """
        Incrementally encode pre-gathered symbol CDF intervals, optionally
        stopping a sequence early once its finalized bitstream would exceed
        ``bit_threshold`` bits.

        Args:
            gathered_cdfs: [B, T, 2] per-step (lower, upper) CDF endpoints of
                the symbol chosen at each step. Endpoints are multiplied by
                the current interval width and truncated to int64, so they
                are presumably normalized to [0, 1) — TODO confirm with the
                caller that gathers them.
            vocab_size: alphabet size; used only to bound the worst-case
                number of output digits per encoded symbol.
            lengths: [B] valid length of each sequence (defaults to T;
                clamped into [0, T]).
            bit_threshold: if given, a sequence is deactivated as soon as its
                finalized bit count exceeds this value, and its output is
                rolled back to the previous step's finalized stream.
            force_padding_to_threshold: pad every output up to
                ``bit_threshold`` bits. NOTE(review): if this is True while
                ``bit_threshold`` is None the padding helper receives None —
                confirm that combination is rejected upstream.
            return_num_padded_bits: also return per-sequence padding counts.

        Returns:
            final_compressed_bytes: List[bytes] — compressed stream per sequence.
            stopped_at_step: List[int] — step at which each sequence stopped.
                Threshold-stopped sequences record the step ``t`` at which the
                threshold was exceeded. NOTE(review): sequences that consume
                all their symbols record ``t + 1`` (== their length), not -1
                as an earlier docstring claimed; -1 only remains for sequences
                that never finalize at all. Confirm which contract callers
                rely on.
            final_num_padded_bits: List[int] — only returned when
                ``return_num_padded_bits`` is True.
        """
        B, T, _ = gathered_cdfs.shape
        device = gathered_cdfs.device

        if lengths is None:
            # No lengths supplied: every sequence uses all T steps.
            lengths = torch.full((B,), T, dtype=torch.int64, device=device)

        lengths = torch.clamp(lengths, min=0, max=T)


        # Arithmetic-coder interval state per sequence: [low, high] spans
        # [0, base**precision - 1] initially; pending carry digits start at 0.
        low = torch.zeros((B,), dtype=torch.int64, device=device)
        high = torch.full((B,), int(self._base**self._precision) - 1, dtype=torch.int64, device=device)
        num_carry_digits = torch.zeros((B,), dtype=torch.int32, device=device)


        # Worst-case output digits: coder precision plus an upper bound of
        # 2 * digits-per-symbol for each of the T steps. One flat buffer
        # holds all B streams; sequence i owns [i*max_digits, (i+1)*max_digits).
        digits_sym = math.ceil(math.log(vocab_size, self._base))
        max_digits = self._precision + 2 * T * digits_sym
        bits_buffer = torch.empty(B * max_digits, dtype=torch.int32, device=device)
        buf_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits


        # Immutable copy of each sequence's buffer start, used to turn
        # absolute write offsets back into per-sequence digit counts.
        base_offsets = torch.arange(B, device=device, dtype=torch.int32) * max_digits


        # Scratch copies used to "finalize" a snapshot of the coder state
        # (append the pending digit + carries) without disturbing the live
        # state that continues encoding on the next step.
        temp_bits_buffer = torch.empty_like(bits_buffer)
        temp_buf_offsets = torch.empty_like(buf_offsets)
        temp_num_carry_digits = torch.empty_like(num_carry_digits)


        # Committed outputs. stopped_at_step stays -1 for sequences that are
        # never finalized (see Returns note above).
        final_buffer = torch.empty_like(bits_buffer)
        final_buffer_ends = torch.zeros(B, dtype=torch.int32, device=device)
        final_num_padded_bits = [None] * B
        stopped_at_step = [-1] * B


        # Sequences still being encoded; cleared on threshold stop or on
        # consuming all their symbols.
        active_sequences = torch.ones(B, dtype=torch.bool, device=device)


        # Snapshot of the previous step's finalized stream — the rollback
        # target when a sequence crosses bit_threshold at the current step.
        prev_finalized_buffer = torch.empty_like(bits_buffer)
        prev_finalized_ends = torch.zeros_like(buf_offsets)

        for t in range(T):
            # Only sequences that are both within their length and still
            # active participate in this step.
            valid = (t < lengths) & active_sequences

            if not valid.any():
                break

            low_valid = low[valid]
            high_valid = high[valid]
            width_valid = high_valid - low_valid + 1

            # Narrow each active interval to the symbol's CDF slice:
            # new range = low + [cdf_lo, cdf_hi) * width.
            old_low = low.clone()
            low[valid] = low_valid + (gathered_cdfs[valid, t, 0] * width_valid).to(torch.int64)
            high[valid] = low_valid + (gathered_cdfs[valid, t, 1] * width_valid).to(torch.int64) - 1

            # Renormalize via the inherited base-class helpers: emit digits
            # on which low and high already agree, then resolve pending
            # carry/underflow digits. (Implementations not visible here —
            # behavior assumed from names; confirm in BatchedArithmeticEncoder.)
            (low, high, bits_buffer, buf_offsets, num_carry_digits, _) = self.flush_matching_digits(
                low, high, old_low,
                encoding=True,
                bits_buffer=bits_buffer,
                buf_offsets=buf_offsets,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )

            (low, high, num_carry_digits, _) = self.flush_carry_digits(
                low, high,
                encoding=True,
                num_carry_digits=num_carry_digits,
                current_code_in_int=None,
                _next_digit=None,
                valid=valid
            )

            # Finalization is only materialized when we must inspect it:
            # either a threshold check is requested, or some sequence is
            # consuming its last symbol this step.
            need_check_threshold = bit_threshold is not None and active_sequences.any()
            some_seq_finished = ((t + 1 >= lengths) & active_sequences).any()

            if need_check_threshold or some_seq_finished:
                # Work on scratch copies so the live coder state is untouched.
                # (Second positional arg to copy_ is non_blocking.)
                temp_bits_buffer.copy_(bits_buffer, True)
                temp_buf_offsets.copy_(buf_offsets, True)
                temp_num_carry_digits.copy_(num_carry_digits, True)


                # Append one terminating digit per sequence: the most
                # significant digit of low. (_base_to_pm1 is presumably
                # base**(precision-1) — confirm in the base class.)
                temp_bits_buffer[temp_buf_offsets] = (low // self._base_to_pm1).to(torch.int32)
                temp_buf_offsets += 1


                # Flush any pending carry digits as (base - 1), vectorized:
                # each selected sequence writes rep_cnt[i] copies starting at
                # its current offset.
                carry_sel = (temp_num_carry_digits > 0).nonzero(as_tuple=False).flatten()
                if carry_sel.numel():
                    carry_digit = self._base - 1
                    rep_cnt = temp_num_carry_digits[carry_sel]
                    repeats_max = rep_cnt.max()
                    grid = torch.arange(repeats_max, device=rep_cnt.device).expand(carry_sel.size(0), repeats_max)
                    mask_rep = grid < rep_cnt.unsqueeze(1)

                    start_pos = temp_buf_offsets[carry_sel]
                    target_pos = (start_pos.unsqueeze(1) + grid)[mask_rep]
                    temp_bits_buffer[target_pos] = carry_digit
                    temp_buf_offsets.index_add_(0, carry_sel, rep_cnt)
                    temp_num_carry_digits[carry_sel] = 0


                if need_check_threshold:
                    current_bit_counts = self._get_bit_counts(temp_buf_offsets, base_offsets)
                    exceeds_threshold = (current_bit_counts > bit_threshold) & active_sequences

                    if exceeds_threshold.any():
                        # Threshold crossed at step t: deactivate and roll the
                        # committed output back to the previous step's snapshot
                        # (so the stored stream stays within the threshold).
                        stopped_indices = exceeds_threshold.nonzero(as_tuple=False).flatten()
                        for idx in stopped_indices.cpu().tolist():
                            active_sequences[idx] = False
                            stopped_at_step[idx] = t

                            final_buffer_ends[idx] = prev_finalized_ends[idx]
                            offset_start = idx * max_digits
                            offset_end = prev_finalized_ends[idx]
                            final_buffer[offset_start:offset_end].copy_(prev_finalized_buffer[offset_start:offset_end])


                # Sequences that just consumed their last symbol (and were not
                # threshold-stopped above) commit the current finalized stream.
                is_final_step = (t + 1 >= lengths) & active_sequences
                if is_final_step.any():
                    final_step_indices = is_final_step.nonzero(as_tuple=False).flatten()
                    for idx in final_step_indices.cpu().tolist():
                        active_sequences[idx] = False
                        stopped_at_step[idx] = t + 1

                        final_buffer_ends[idx] = temp_buf_offsets[idx]

                        offset_start = idx * max_digits
                        offset_end = temp_buf_offsets[idx]
                        final_buffer[offset_start:offset_end].copy_(temp_bits_buffer[offset_start:offset_end])


                # Remember this step's finalized stream as next step's
                # rollback target (only needed on the threshold path).
                if need_check_threshold:
                    prev_finalized_buffer.copy_(temp_bits_buffer)
                    prev_finalized_ends.copy_(temp_buf_offsets)


        # Convert each committed digit stream to bytes. The "bits" naming and
        # _get_bit_counts suggest base == 2 in practice — confirm; for other
        # bases the joined string would contain multi-valued digits.
        final_compressed_bytes = []

        for idx in range(B):
            offset_start = idx * max_digits
            offset_end = final_buffer_ends[idx]
            bits_list = final_buffer[offset_start:offset_end].cpu().tolist()
            bitstr = "".join(map(str, bits_list))
            if force_padding_to_threshold:
                comp_bytes, num_padded = utils.bits_to_bytes_padding_to_threshold(bitstr, bit_threshold)
            else:
                comp_bytes, num_padded = utils.bits_to_bytes(bitstr)
            final_compressed_bytes.append(comp_bytes)
            if return_num_padded_bits:
                final_num_padded_bits[idx] = num_padded

        if return_num_padded_bits:
            return final_compressed_bytes, stopped_at_step, final_num_padded_bits
        else:
            return final_compressed_bytes, stopped_at_step
|
|