| import numpy as np |
| from typing import Tuple, Callable |
| from m1_compression import arithmetic_coder |
| from m1_compression import utils |
| import torch |
| import logging |
| from pathlib import Path |
| import time |
| from apps.main.transformer import LMTransformer, LMTransformerArgs |
| from lingua.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints |
| from apps.main.generate import ( |
| load_consolidated_model_and_tokenizer, |
| ) |
| from lingua.args import dataclass_from_dict |
| from lingua.checkpoint import CONSOLIDATE_NAME |
| from omegaconf import OmegaConf |
logger = logging.getLogger()


# Size of the symbol alphabet: one symbol per possible byte value.
ALPHABET_SIZE = 256

# Output base of the arithmetic coder; base 2 emits individual bits.
ARITHMETIC_CODER_BASE = 2

# Bits of precision for the arithmetic coder's internal integer state.
ARITHMETIC_CODER_PRECISION = 32
# Bit budget of each compression window (see the *_with_windows functions).
WINDOW_SIZE = 32
|
|
def load_m1_model_and_tokenizer(consolidated_path: str):
    """Load an M1 transformer checkpoint via the shared consolidated loader.

    Resolves ``consolidated_path`` to a consolidated checkpoint directory,
    consolidating sharded checkpoint files on the fly when necessary.

    Args:
        consolidated_path: Path to the model checkpoint directory.

    Returns:
        Tuple ``(model, tokenizer, predict_fn)`` where ``predict_fn`` maps a
        byte-value sequence to next-token probability distributions.
    """
    ckpt_dir = Path(consolidated_path)
    # Use the directory directly if it already holds a consolidated
    # checkpoint (params.json plus at least one *.pth file).
    if (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and next(ckpt_dir.glob("*.pth"), None) is not None
    ):
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            # Merge sharded checkpoint files into a consolidated one.
            consolidate_path = consolidate_checkpoints(ckpt_dir)

    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        str(consolidate_path),
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    predict_fn = get_predict_fn(model, tokenizer)
    return model, tokenizer, predict_fn
|
|
def load_m1_model_cpu(consolidated_path: str):
    """Load an M1 transformer checkpoint onto the CPU.

    Unlike ``load_m1_model_and_tokenizer`` this reads the consolidated state
    dict directly rather than going through the shared loader, so no GPU is
    required.

    Args:
        consolidated_path: Path to the model checkpoint directory.

    Returns:
        The model in eval mode, with parameters cast to the dtype recorded
        in the checkpoint's config.
    """
    ckpt_dir = Path(consolidated_path)
    # Use the directory directly if it already holds a consolidated
    # checkpoint (params.json plus at least one *.pth file).
    if (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and next(ckpt_dir.glob("*.pth"), None) is not None
    ):
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            # Merge sharded checkpoint files into a consolidated one.
            consolidate_path = consolidate_checkpoints(ckpt_dir)

    logger.info("Loading model")
    ckpt_path = Path(consolidate_path)
    config = OmegaConf.load(ckpt_path / "params.json")

    # Cast parameters to the dtype the model was trained with.
    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(LMTransformerArgs, config.model, strict=False)
    model = LMTransformer(model_args)
    # map_location keeps weights on the CPU, matching this function's contract
    # (the original relied on the checkpoint's saved device).
    st_dict = torch.load(
        ckpt_path / CONSOLIDATE_NAME, weights_only=True, map_location="cpu"
    )
    model.load_state_dict(st_dict["model"])
    # A single eval() call suffices (the original called it twice).
    model.eval()
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    logger.info("Model loaded")
    return model
|
|
def get_predict_fn(model, tokenizer):
    """Build a next-byte prediction function backed by ``model``.

    Args:
        model: Autoregressive model mapping token ids to logits; assumed to
            run on CUDA (inputs are moved with ``.cuda()``).
        tokenizer: Unused here (byte values serve directly as token ids);
            kept for interface compatibility with callers.

    Returns:
        Callable mapping an integer sequence to per-position probability
        distributions over the 256 byte values.
    """
    def predict_fn(input_sequence: np.ndarray) -> np.ndarray:
        """Return next-token probabilities for ``input_sequence``.

        Args:
            input_sequence: Token ids, shape (batch, seq_len) or (seq_len,).

        Returns:
            Probabilities of shape (batch, seq_len, ALPHABET_SIZE).
        """
        # With no context the model has nothing to condition on; fall back
        # to the uniform distribution over bytes.
        if input_sequence.size == 0:
            return np.full(
                (1, 1, ALPHABET_SIZE), 1.0 / ALPHABET_SIZE, dtype=np.float32
            )

        input_tensor = torch.tensor(input_sequence, dtype=torch.long).cuda()
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            logits = model(input_tensor)

        # Restrict to the byte alphabet in case the model vocab is larger.
        logits = logits[..., :ALPHABET_SIZE].float()
        assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
        probs = torch.softmax(logits, dim=-1)
        return probs.float().cpu().numpy()

    return predict_fn
|
|
|
|
def m1_arithmetic_compress(
    data: bytes,
    predict_fn: Callable,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | tuple[bytes, int]:
    """Compress ``data`` with arithmetic coding driven by a language model.

    Args:
        data: Raw bytes to compress.
        predict_fn: Maps a (1, seq_len) array of byte values to probabilities
            of shape (1, seq_len, ALPHABET_SIZE).
        return_num_padded_bits: If True, also return the number of bits used
            to pad the bitstream up to a whole number of bytes.
        use_slow_lossless_compression: If True, re-run the model on every
            prefix (exactly mirrors autoregressive decoding); otherwise run
            the model once over the full sequence and shift the outputs.

    Returns:
        The compressed bytes, optionally paired with the padded-bit count.
    """
    sequence_array = np.frombuffer(data, dtype=np.uint8)

    if use_slow_lossless_compression:
        probs = []
        for k in range(len(sequence_array)):
            input_seq = sequence_array[:k]
            if input_seq.size == 0:
                # No context yet: uniform prior over the byte alphabet.
                current_probs = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
            else:
                model_probs = predict_fn(input_seq[None])
                current_probs = model_probs[0, -1]
            probs.append(current_probs)
        # reshape keeps a (0, ALPHABET_SIZE) shape for empty input so the
        # shape checks below remain valid (the original crashed on b"").
        probs = np.array(probs, dtype=np.float32).reshape(-1, ALPHABET_SIZE)
    else:
        if sequence_array.size == 0:
            probs = np.empty((0, ALPHABET_SIZE), dtype=np.float32)
        else:
            full_probs = predict_fn(sequence_array[None])[0, ...]
            # Shift by one: symbol k is coded under the model's prediction
            # from the first k symbols; the first symbol uses the uniform prior.
            probs = np.concatenate(
                [
                    np.ones(ALPHABET_SIZE, dtype=np.float32)[None] / ALPHABET_SIZE,
                    full_probs[:-1, ...]
                ], axis=0
            )

    probs /= probs.sum(axis=-1, keepdims=True)

    assert probs.shape[1] == ALPHABET_SIZE, "The shape of probs is not correct."
    assert probs.shape[0] == len(sequence_array), "The shape of probs is not correct."
    if len(sequence_array):
        assert np.isclose(sum(probs[0]), 1, atol=1e-6), "The probs is not normalized."

    output = []
    encoder = arithmetic_coder.Encoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        output_fn=output.append
    )
    for pdf, symbol in zip(probs, sequence_array):
        encoder.encode(utils.normalize_pdf_for_arithmetic_coding(pdf), symbol)
    encoder.terminate()

    compressed_bits = ''.join(map(str, output))

    compressed_bytes, num_padded_bits = utils.bits_to_bytes(compressed_bits)

    if return_num_padded_bits:
        return compressed_bytes, num_padded_bits
    return compressed_bytes
| |
def m1_arithmetic_decompress(
    compressed: bytes,
    predict_fn: Callable,
    num_padded_bits: int,
    length: int,
) -> bytes:
    """Invert ``m1_arithmetic_compress``.

    Args:
        compressed: The arithmetic-coded bitstream, packed into bytes.
        predict_fn: The same prediction function used during compression.
        num_padded_bits: Number of padding bits appended when packing.
        length: Number of bytes to decode.

    Returns:
        The decompressed data as ``bytes`` (the original annotation claimed
        ``np.ndarray`` but ``tobytes()`` was always returned).
    """
    bits = utils.bytes_to_bits(compressed, num_padded_bits=num_padded_bits)
    bit_iter = iter(bits)

    def _input_fn() -> int | None:
        # The decoder pulls one bit at a time; signal exhaustion with None.
        try:
            return int(next(bit_iter))
        except StopIteration:
            return None

    decoder = arithmetic_coder.Decoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        input_fn=_input_fn
    )
    # Accumulate decoded symbols in a list (avoids the original's O(n^2)
    # np.concatenate-per-symbol growth pattern).
    decoded: list[int] = []
    for k in range(length):
        if k == 0:
            # The first symbol was coded under the uniform prior.
            current_probs = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
        else:
            context = np.array(decoded, dtype=np.uint8)
            current_probs = predict_fn(context[None])[0, -1]

        current_probs /= current_probs.sum()
        token = decoder.decode(utils.normalize_pdf_for_arithmetic_coding(current_probs))
        decoded.append(token)

    return np.array(decoded, dtype=np.uint8).tobytes()
|
|
def m1_arithmetic_compress_with_windows(
    data: bytes,
    predict_fn: Callable,
    window_bit_size=WINDOW_SIZE,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | Tuple[bytes, list[int], list[int]]:
    """Compress ``data`` into fixed-size windows of ``window_bit_size`` bits.

    Bytes are greedily accumulated into a window until its compressed form
    would exceed ``window_bit_size`` bits; the window is then flushed and a
    new one started. Every flushed window (including the last) is zero-padded
    to exactly ``window_bit_size`` bits, so the concatenated output can be
    sliced back into windows at fixed bit offsets during decompression.
    (The original appended unpadded windows, which desynchronized the
    decompressor's fixed-offset slicing whenever a window compressed to
    fewer than ``window_bit_size // 8`` bytes.)

    Args:
        data: Raw bytes to compress.
        predict_fn: Model prediction function (see ``m1_arithmetic_compress``).
        window_bit_size: Bit budget per window; must be a multiple of 8.
        return_num_padded_bits: If True, also return the per-window padding
            bit counts and per-window original byte lengths.
        use_slow_lossless_compression: Forwarded to ``m1_arithmetic_compress``.

    Returns:
        The concatenated compressed windows, optionally with the per-window
        padding counts and original lengths.

    Raises:
        ValueError: If a single byte cannot be coded within the bit budget.
    """
    assert window_bit_size % 8 == 0, "window_bit_size must be a multiple of 8"
    window_byte_size = window_bit_size // 8

    compressed_windows = []
    num_padded_bits_per_window = []
    original_lengths_per_window = []
    current_window = []

    def _compress_window(window: list) -> tuple[bytes, str]:
        """Compress one window; return (packed bytes, unpadded bit string)."""
        compressed, num_padded = m1_arithmetic_compress(
            bytes(window),
            predict_fn=predict_fn,
            return_num_padded_bits=True,
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        bits = ''.join(map(str, utils.bytes_to_bits(compressed, num_padded)))
        return compressed, bits

    def _flush(window: list) -> None:
        """Pad a finished window to window_bit_size bits and record it."""
        compressed, bits = _compress_window(window)
        assert len(bits) <= window_bit_size, "window exceeds its bit budget"
        # Zero-pad the packed bytes up to the fixed window size so every
        # window occupies exactly window_bit_size bits in the joined stream.
        compressed_windows.append(compressed.ljust(window_byte_size, b'\x00'))
        # Record padding relative to the fixed window size; the decompressor
        # strips exactly this many trailing bits after slicing.
        num_padded_bits_per_window.append(window_bit_size - len(bits))
        original_lengths_per_window.append(len(window))

    for byte in data:
        current_window.append(byte)
        _, bits = _compress_window(current_window)
        if len(bits) > window_bit_size:
            # Adding this byte overflowed the budget: flush the window
            # without it and start a fresh window containing just this byte.
            current_window.pop()
            if not current_window:
                raise ValueError(
                    "a single byte does not fit in window_bit_size bits"
                )
            _flush(current_window)
            current_window = [byte]

    if current_window:
        _flush(current_window)

    all_compressed_bytes = b''.join(compressed_windows)

    if return_num_padded_bits:
        return all_compressed_bytes, num_padded_bits_per_window, original_lengths_per_window
    return all_compressed_bytes
|
|
|
|
def m1_arithmetic_decompress_with_windows(
    compressed: bytes,
    predict_fn: Callable,
    window_bit_size,
    num_padded_bits_per_window: list[int],
    original_lengths_per_window: list[int]
) -> bytes:
    """Invert ``m1_arithmetic_compress_with_windows``.

    The compressed stream is sliced into consecutive ``window_bit_size``-bit
    windows; each is decompressed independently using its recorded padding
    bit count and original byte length.

    Args:
        compressed: Concatenated compressed windows.
        predict_fn: The same prediction function used during compression.
        window_bit_size: Bit width each window occupies in the stream.
        num_padded_bits_per_window: Per-window trailing padding bit counts.
        original_lengths_per_window: Per-window decoded byte lengths.

    Returns:
        The decompressed data as bytes.
    """
    decoded_bytes = b''
    start = 0
    bitstream = utils.bytes_to_bits(compressed)
    for num_padded, length in zip(num_padded_bits_per_window, original_lengths_per_window):
        window_bitstream = bitstream[start:start + window_bit_size]
        window_compressed, _ = utils.bits_to_bytes(window_bitstream)
        # Debug tracing via the module logger instead of print (lazy %-args
        # avoid formatting cost when debug logging is disabled).
        logger.debug("当前窗口压缩数据: %s", window_compressed)
        decoded_window = m1_arithmetic_decompress(
            window_compressed,
            predict_fn,
            num_padded,
            length,
        )
        logger.debug("解压缩窗口: %s", decoded_window)
        logger.debug("解压缩窗口的长度: %s", len(decoded_window))
        decoded_bytes += decoded_window
        start += window_bit_size
    return decoded_bytes
|
|
|
|
def test_equal_window_compression(sequence: str):
    """Round-trip ``sequence`` through windowed compression and report stats.

    Args:
        sequence: Text to tokenize, compress, decompress, and verify.
    """
    print(f"测试序列: {sequence}")
    model, tokenizer, predict_fn = load_m1_model_and_tokenizer(consolidated_path = "/mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_1M_steps10k_bs32_seqlen2048/checkpoints/0000010000")
    original_bytes = tokenizer.encode(sequence)
    print(f"token 字节: {original_bytes}")
    original_bytes = bytes(original_bytes)
    print(f"转为字节流: {original_bytes}")
    compressed_data = m1_arithmetic_compress_with_windows(
        original_bytes,
        predict_fn=predict_fn,
        window_bit_size=WINDOW_SIZE,
        return_num_padded_bits=True,
        use_slow_lossless_compression=False
    )
    compressed_bytes, num_padded_bits_per_window, original_lengths_per_window = compressed_data
    print(f"压缩后的字节: {compressed_bytes}")
    print(f"窗口填充位数数组: {num_padded_bits_per_window}")
    print(f"窗口原始长度数组: {original_lengths_per_window}")
    try:
        decoded_bytes = m1_arithmetic_decompress_with_windows(
            compressed_bytes,
            predict_fn,
            WINDOW_SIZE,
            num_padded_bits_per_window,
            original_lengths_per_window
        )
        decoded_sequence = decoded_bytes.decode('utf-8')

        assert decoded_sequence == sequence, f"解码失败:原始={sequence}, 解码={decoded_sequence}"
    # Catch only the expected round-trip failures; the original bare
    # `except:` silently swallowed every error, including real coding bugs
    # and KeyboardInterrupt.
    except (AssertionError, UnicodeDecodeError) as exc:
        print(f"解码失败: {exc}")

    # Guard against division by zero when nothing was compressed.
    if compressed_bytes:
        compression_ratio = len(original_bytes) / len(compressed_bytes)
    else:
        compression_ratio = float('inf')
    print(f"原始大小: {len(original_bytes)} bytes")
    print(f"压缩后大小: {len(compressed_bytes)} bytes")
    print(f"各窗口填充位数: {num_padded_bits_per_window}")
    print(f"各窗口原始长度: {original_lengths_per_window}")
    print(f"压缩率: {compression_ratio:.2f}x")
| |
|
|
|
|
if __name__ == "__main__":
    test_sequences = [
        "if month % 2 : return 30",
        "def __init__(self, name): self.name = name",
        "import pandas as pd\n import matplotlib.pyplot as plt",
        "import torch",
        # BUG FIX: a comma was missing after the next literal, so Python's
        # implicit string concatenation merged it with the following `* 100`
        # string into one giant test case (5 list elements instead of 6).
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as nn",
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F" * 100,
    ]

    for seq in test_sequences:
        print(f"=== m1语言模型 Equal_Window 算术压缩 ===")
        start_time = time.time()
        test_equal_window_compression(seq)
        end_time = time.time()
        print(f"用时: {end_time - start_time} 秒")
|
|