# offline_compression_graph_code (checkpoint 72c0672)
import numpy as np
from typing import Tuple, Callable
from m1_compression import arithmetic_coder
from m1_compression import utils
import torch
import logging
from pathlib import Path
import time
from apps.main.transformer import LMTransformer, LMTransformerArgs
from lingua.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from apps.main.generate import (
load_consolidated_model_and_tokenizer,
)
from lingua.args import dataclass_from_dict
from lingua.checkpoint import CONSOLIDATE_NAME
from omegaconf import OmegaConf
logger = logging.getLogger()
ALPHABET_SIZE = 256
# Base 2 means that the coder writes bits.
ARITHMETIC_CODER_BASE = 2
# Precision 16 implies 16 bit arithmetic, in the original paper it is 32.
ARITHMETIC_CODER_PRECISION = 32
WINDOW_SIZE = 32 # window size in bits; adjust as needed
def load_m1_model_and_tokenizer(consolidated_path: str):
    """Load a consolidated M1 checkpoint plus its tokenizer and prediction fn.

    Args:
        consolidated_path (str): Path to the model checkpoint directory.

    Returns:
        Tuple of (model, tokenizer, predict_fn), with the model in eval mode.
    """
    ckpt_dir = Path(consolidated_path)
    # The directory is already consolidated when it holds params.json and at
    # least one *.pth weight file; otherwise look for (or create) the
    # consolidated sub-folder.
    has_params = (ckpt_dir / "params.json").exists()
    has_weights = next(ckpt_dir.glob("*.pth"), None) is not None
    if ckpt_dir.exists() and has_params and has_weights:
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            consolidate_path = consolidate_checkpoints(ckpt_dir)
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        str(consolidate_path),
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    return model, tokenizer, get_predict_fn(model, tokenizer)
def load_m1_model_cpu(consolidated_path: str):
    """Load an M1 LMTransformer checkpoint on CPU (no tokenizer).

    Args:
        consolidated_path (str): Path to the model checkpoint directory.

    Returns:
        The model in eval mode with parameters cast to the dtype recorded in
        the checkpoint config (``distributed.model_dtype``).
    """
    ckpt_dir = Path(consolidated_path)
    # A directory that already contains params.json and a *.pth file is a
    # consolidated checkpoint; otherwise fall back to (or build) the
    # consolidated sub-folder.
    if (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and next(ckpt_dir.glob("*.pth"), None) is not None
    ):
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            consolidate_path = consolidate_checkpoints(ckpt_dir)
    logger.info("Loading model")
    ckpt_path = Path(consolidate_path)
    config = OmegaConf.load(ckpt_path / "params.json")
    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(LMTransformerArgs, config.model, strict=False)
    model = LMTransformer(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"])
    # Cast weights to the configured dtype after loading.
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    logger.info("Model loaded")
    # Fix: the original called model.eval() twice (once mid-function, once at
    # the end); a single call is sufficient.
    model.eval()
    return model
def get_predict_fn(model, tokenizer):
    """Build a next-byte prediction function for *model*.

    Args:
        model: autoregressive model mapping token-id tensors to logits.
        tokenizer: kept for interface symmetry; not used by the closure.

    Returns:
        ``predict_fn(input_sequence) -> np.ndarray`` giving per-position
        probability distributions over the byte alphabet.
    """
    def predict_fn(input_sequence: np.ndarray) -> np.ndarray:
        """
        Args:
            input_sequence (np.ndarray): input sequence, shape
                (batch_size, seq_len) or (seq_len,).

        Returns:
            np.ndarray: per-token probabilities, shape
            (batch_size, seq_len, ALPHABET_SIZE).
        """
        # Empty input: return a uniform distribution so compressor and
        # decompressor agree on the prior for the very first symbol.
        if input_sequence.size == 0:
            return np.full((1, 1, ALPHABET_SIZE), 1.0 / ALPHABET_SIZE, dtype=np.float32)
        # NOTE(review): assumes a CUDA device is available — confirm before
        # running on CPU-only hosts.
        input_tensor = torch.tensor(input_sequence, dtype=torch.long).cuda()
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            logits = model(input_tensor)
        # Keep only the byte alphabet (fix: use the module constant
        # ALPHABET_SIZE instead of the magic number 256 scattered here).
        logits = logits[..., :ALPHABET_SIZE].float()
        assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
        probs = torch.softmax(logits, dim=-1)
        return probs.float().cpu().numpy()
    return predict_fn
def m1_arithmetic_compress(
    data: bytes,
    predict_fn: Callable,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | tuple[bytes, int]:
    """Compress *data* with a language model + arithmetic coding.

    Args:
        data: raw bytes to compress.
        predict_fn: callable returning next-byte probability distributions.
        return_num_padded_bits: also return the number of zero bits that
            padded the bitstream to a whole number of bytes.
        use_slow_lossless_compression: re-run the model once per symbol
            (exactly mirroring decompression) instead of one batched
            forward pass.

    Returns:
        Compressed bytes, optionally paired with the padded-bit count.
    """
    sequence_array = np.frombuffer(data, dtype=np.uint8)
    uniform = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
    if use_slow_lossless_compression:
        probs = []
        # Predict symbol k from the k-symbol prefix (uniform prior at k == 0).
        for k in range(len(sequence_array)):
            input_seq = sequence_array[:k]
            if input_seq.size == 0:
                current_probs = uniform
            else:
                model_probs = predict_fn(input_seq[None])  # (1, k, ALPHABET_SIZE)
                current_probs = model_probs[0, -1]  # last position predicts symbol k
            probs.append(current_probs)
        # reshape keeps a valid (0, ALPHABET_SIZE) matrix for empty input,
        # which used to crash the shape assertions below.
        probs = np.array(probs, dtype=np.float32).reshape(-1, ALPHABET_SIZE)
    else:
        if sequence_array.size == 0:
            # Fix: the original left a 1-D vector here, which crashed on
            # probs.shape[1] below; use an empty (0, ALPHABET_SIZE) matrix.
            probs = np.empty((0, ALPHABET_SIZE), dtype=np.float32)
        else:
            # Single forward pass: position i predicts symbol i+1, so shift by
            # one and prepend the uniform prior for the first symbol.
            full_probs = predict_fn(sequence_array[None])[0, ...]
            probs = np.concatenate([uniform[None], full_probs[:-1, ...]], axis=0)
    if probs.shape[0]:
        probs /= probs.sum(axis=-1, keepdims=True)
        assert np.isclose(probs[0].sum(), 1, atol=1e-6), "The probs is not normalized."
    assert probs.shape == (len(sequence_array), ALPHABET_SIZE), "The shape of probs is not correct."
    output = []
    encoder = arithmetic_coder.Encoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        output_fn=output.append
    )
    for pdf, symbol in zip(probs, sequence_array):
        encoder.encode(utils.normalize_pdf_for_arithmetic_coding(pdf), symbol)
    encoder.terminate()
    compressed_bits = ''.join(map(str, output))
    # Pad with zero bits so the bitstream becomes a whole number of bytes.
    compressed_bytes, num_padded_bits = utils.bits_to_bytes(compressed_bits)
    if return_num_padded_bits:
        return compressed_bytes, num_padded_bits
    return compressed_bytes
def m1_arithmetic_decompress(
    compressed: bytes,
    predict_fn: Callable,
    num_padded_bits: int,
    length: int,
) -> bytes:
    """Invert :func:`m1_arithmetic_compress`.

    Args:
        compressed: compressed byte stream.
        predict_fn: the same prediction function used for compression.
        num_padded_bits: zero bits appended when the bitstream was byte-padded.
        length: number of symbols (bytes) to decode.

    Returns:
        The decompressed bytes.  (Fix: the original annotation claimed
        ``np.ndarray`` while the function returns ``tobytes()``.)
    """
    bits = utils.bytes_to_bits(compressed, num_padded_bits=num_padded_bits)
    data_iter = iter(bits)
    def _input_fn() -> int | None:
        # Feed the decoder one bit at a time; None signals exhaustion.
        try:
            return int(next(data_iter))
        except StopIteration:
            return None
    decoder = arithmetic_coder.Decoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        input_fn=_input_fn
    )
    sequence_array = np.empty((0,), dtype=np.uint8)
    for k in range(length):
        if k == 0:
            # First symbol: empty context, uniform prior — must mirror the
            # compressor exactly for losslessness.
            current_probs = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
        else:
            # Condition on all k symbols decoded so far.
            model_probs = predict_fn(sequence_array[None])  # (1, k, ALPHABET_SIZE)
            current_probs = model_probs[0, -1]
            # Re-normalize in case the model output is not exactly normalized.
            current_probs /= current_probs.sum()
        token = decoder.decode(utils.normalize_pdf_for_arithmetic_coding(current_probs))
        # Concatenate with an explicit uint8 dtype: np.append silently widened
        # the dtype and corrupted the byte stream (observed leading zeros).
        sequence_array = np.concatenate([sequence_array, np.array([token], dtype=np.uint8)])
    return sequence_array.tobytes()
def m1_arithmetic_compress_with_windows(
    data: bytes,
    predict_fn: Callable,
    window_bit_size=WINDOW_SIZE,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | Tuple[bytes, list[int], list[int]]:
    """Compress *data* in windows whose compressed bitstream stays within
    *window_bit_size* bits.

    Bytes are appended to the current window one at a time; after each append
    the whole window is recompressed, and when the compressed bitstream
    exceeds the budget the last byte is removed, the window is finalized, and
    a new window starts with that byte.

    NOTE(review): finalized windows' compressed bytes are simply concatenated,
    while m1_arithmetic_decompress_with_windows slices the bitstream in fixed
    window_bit_size chunks — this only round-trips if every window occupies
    exactly window_bit_size bits after byte padding; confirm the
    bits_to_bytes/bytes_to_bits alignment.

    Args:
        data: raw bytes to compress.
        predict_fn: next-byte probability function.
        window_bit_size: per-window budget for the compressed bitstream, in bits.
        return_num_padded_bits: also return per-window padding and lengths.
        use_slow_lossless_compression: forwarded to m1_arithmetic_compress.

    Returns:
        Concatenated compressed bytes, optionally with the per-window
        padded-bit counts and original (uncompressed) window lengths.
    """
    compressed_windows = []
    num_padded_bits_per_window = []
    original_lengths_per_window = []
    current_window = []
    for byte in data:
        # Tentatively add this byte to the current window.
        current_window.append(byte)
        # Compress the window to check whether it still fits the bit budget.
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window),
            predict_fn=predict_fn,
            return_num_padded_bits=True,
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        # Reconstruct the unpadded compressed bitstream to measure its length.
        compressed_bits = ''.join(map(str, utils.bytes_to_bits(compressed, num_padded)))
        # Over budget: drop the byte, finalize the window without it, and
        # start a new window containing the dropped byte.
        if len(compressed_bits) > window_bit_size:
            current_window.pop()
            # Recompress the truncated window and keep the result.
            compressed, num_padded = m1_arithmetic_compress(
                bytes(current_window),
                predict_fn=predict_fn,
                return_num_padded_bits=True,
                use_slow_lossless_compression=use_slow_lossless_compression,
            )
            compressed_bits = ''.join(map(str, utils.bytes_to_bits(compressed, num_padded)))
            compressed_windows.append(compressed)
            num_padded_bits_per_window.append(num_padded)
            original_lengths_per_window.append(len(current_window))
            current_window = [byte]
    # Finalize the trailing (possibly under-budget) window.
    if current_window:
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window),
            predict_fn=predict_fn,
            return_num_padded_bits=True,
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        compressed_windows.append(compressed)
        num_padded_bits_per_window.append(num_padded)
        original_lengths_per_window.append(len(current_window))
    all_compressed_bytes = b''.join(compressed_windows)
    if return_num_padded_bits:
        return all_compressed_bytes, num_padded_bits_per_window, original_lengths_per_window
    return all_compressed_bytes
def m1_arithmetic_decompress_with_windows(
    compressed: bytes,
    predict_fn: Callable,
    window_bit_size,
    num_padded_bits_per_window: list[int],
    original_lengths_per_window: list[int]
) -> bytes:
    """Invert :func:`m1_arithmetic_compress_with_windows`.

    Slices the compressed bitstream into fixed *window_bit_size* chunks and
    decodes each window independently with m1_arithmetic_decompress.

    Args:
        compressed: concatenated per-window compressed bytes.
        predict_fn: the same prediction function used for compression.
        window_bit_size: per-window bit budget used during compression.
        num_padded_bits_per_window: padding bit counts, one per window.
        original_lengths_per_window: decoded byte counts, one per window.

    Returns:
        The decompressed bytes.
    """
    decoded_bytes = b''
    start = 0
    bitstream = utils.bytes_to_bits(compressed)
    for num_padded, length in zip(num_padded_bits_per_window, original_lengths_per_window):
        # Take this window's fixed-size slice of the bitstream.
        window_bitstream = bitstream[start:start + window_bit_size]
        # Turn the bit slice back into bytes for the decoder.
        window_compressed, _ = utils.bits_to_bytes(window_bitstream)
        # Fix: debug print() calls replaced with lazy logger.debug so library
        # code does not spam stdout.
        logger.debug("window compressed data: %s", window_compressed)
        decoded_window = m1_arithmetic_decompress(
            window_compressed,
            predict_fn,
            num_padded,
            length,
        )
        logger.debug("decoded window (%d bytes): %s", len(decoded_window), decoded_window)
        decoded_bytes += decoded_window
        # Advance to the next fixed-size window.
        start += window_bit_size
    return decoded_bytes
def test_equal_window_compression(sequence: str):
    """Round-trip *sequence* through windowed compression and print stats.

    Args:
        sequence: text to encode, compress, decompress and compare.
    """
    print(f"测试序列: {sequence}")
    # NOTE(review): checkpoint path is hard-coded for a specific machine.
    model, tokenizer, predict_fn = load_m1_model_and_tokenizer(consolidated_path = "/mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_1M_steps10k_bs32_seqlen2048/checkpoints/0000010000")
    original_bytes = tokenizer.encode(sequence)
    print(f"token 字节: {original_bytes}")
    original_bytes = bytes(original_bytes)
    print(f"转为字节流: {original_bytes}")
    compressed_data = m1_arithmetic_compress_with_windows(
        original_bytes,
        predict_fn=predict_fn,
        window_bit_size=WINDOW_SIZE,
        return_num_padded_bits=True,
        use_slow_lossless_compression=False
    )
    compressed_bytes, num_padded_bits_per_window, original_lengths_per_window = compressed_data
    print(f"压缩后的字节: {compressed_bytes}")
    print(f"窗口填充位数数组: {num_padded_bits_per_window}")
    print(f"窗口原始长度数组: {original_lengths_per_window}")
    try:
        decoded_bytes = m1_arithmetic_decompress_with_windows(
            compressed_bytes,
            predict_fn,
            WINDOW_SIZE,
            num_padded_bits_per_window,
            original_lengths_per_window
        )
        decoded_sequence = decoded_bytes.decode('utf-8')
        assert decoded_sequence == sequence, f"解码失败:原始={sequence}, 解码={decoded_sequence}"
    except Exception as exc:
        # Fix: a bare ``except:`` swallowed every error (including
        # KeyboardInterrupt) without any detail; narrow it and report the
        # actual failure.
        print(f"解码失败: {exc}")
    compression_ratio = len(original_bytes) / len(compressed_bytes)
    print(f"原始大小: {len(original_bytes)} bytes")
    print(f"压缩后大小: {len(compressed_bytes)} bytes")
    print(f"各窗口填充位数: {num_padded_bits_per_window}")
    print(f"各窗口原始长度: {original_lengths_per_window}")
    print(f"压缩率: {compression_ratio:.2f}x")
if __name__ == "__main__":
    test_sequences = [
        "if month % 2 : return 30",
        "def __init__(self, name): self.name = name",
        "import pandas as pd\n import matplotlib.pyplot as plt",
        "import torch",
        # Fix: a missing comma here used to implicitly concatenate the next
        # two literals into one merged test case repeated 100 times.
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as nn",
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F" * 100,
    ]
    for seq in test_sequences:
        print(f"=== m1语言模型 Equal_Window 算术压缩 ===")
        start_time = time.time()
        test_equal_window_compression(seq)
        end_time = time.time()
        print(f"用时: {end_time - start_time} 秒")