File size: 16,052 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import numpy as np
from typing import Tuple, Callable
from m1_compression import arithmetic_coder
from m1_compression import utils
import torch
import logging
from pathlib import Path
import time
from apps.main.transformer import LMTransformer, LMTransformerArgs
from lingua.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
from apps.main.generate import (
    load_consolidated_model_and_tokenizer,
)
from lingua.args import dataclass_from_dict
from lingua.checkpoint import CONSOLIDATE_NAME
from omegaconf import OmegaConf
logger = logging.getLogger()

# One symbol per byte value.
ALPHABET_SIZE = 256
# Base 2 means that the coder writes bits.
ARITHMETIC_CODER_BASE = 2
# Coder precision in bits. NOTE(review): an earlier comment here said 16-bit,
# but the value is 32 (the original paper's setting) — confirm which is intended.
ARITHMETIC_CODER_PRECISION = 32
WINDOW_SIZE = 32  # Window size in bits; adjust as needed.

def load_m1_model_and_tokenizer(consolidated_path: str):
    """Load an M1 checkpoint (consolidating it first if needed) plus its predict fn.

    Args:
        consolidated_path (str): Path to the model checkpoint directory.

    Returns:
        (model, tokenizer, predict_fn); the model is left in eval mode.
    """
    ckpt_dir = Path(consolidated_path)
    # The directory is usable as-is when it already holds params.json and at
    # least one *.pth file; otherwise fall back to the consolidation folder,
    # creating it on demand.
    already_consolidated = (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and any(ckpt_dir.glob("*.pth"))
    )
    if already_consolidated:
        consolidate_path = ckpt_dir
    else:
        consolidate_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not consolidate_path.exists():
            consolidate_path = consolidate_checkpoints(ckpt_dir)

    # Load model and tokenizer through the shared helper API.
    logger.info("Loading model")
    model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
        str(consolidate_path),
        model_cls=LMTransformer,
        model_args_cls=LMTransformerArgs,
    )
    logger.info("Model loaded")
    model.eval()
    return model, tokenizer, get_predict_fn(model, tokenizer)

def load_m1_model_cpu(consolidated_path: str):
    """Load a consolidated M1 checkpoint onto the CPU.

    Args:
        consolidated_path (str): Path to the model checkpoint directory (either
            already consolidated, or containing/creating a CONSOLIDATE_FOLDER).

    Returns:
        The model in eval mode, parameters cast to the dtype recorded in the
        checkpoint's training config (fp32/fp16/bf16).
    """
    ckpt_dir = Path(consolidated_path)
    # Use the directory directly if it already holds a consolidated checkpoint
    # (params.json plus at least one *.pth); otherwise consolidate first.
    if (
        ckpt_dir.exists()
        and (ckpt_dir / "params.json").exists()
        and next(ckpt_dir.glob("*.pth"), None) is not None
    ):
        ckpt_path = ckpt_dir
    else:
        ckpt_path = ckpt_dir / CONSOLIDATE_FOLDER
        if not ckpt_path.exists():
            ckpt_path = consolidate_checkpoints(ckpt_dir)

    logger.info("Loading model")
    config = OmegaConf.load(ckpt_path / "params.json")

    # Map the config's dtype string onto the torch dtype.
    param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
        config.distributed.model_dtype
    ]
    model_args = dataclass_from_dict(LMTransformerArgs, config.model, strict=False)
    model = LMTransformer(model_args)
    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
    model.load_state_dict(st_dict["model"])
    # Cast weights to the configured dtype after loading.
    for param in model.parameters():
        param.data = param.data.to(dtype=param_dtype)
    logger.info("Model loaded")
    # Fix: eval() was previously called twice; once is enough.
    return model.eval()

def get_predict_fn(model, tokenizer):
    """
    return a function that takes a sequence of tokens and returns the probability distribution of the next token.
    Args:
        model: the model to use for prediction
        tokenizer: the tokenizer to use for encoding/decoding
    """
    def predict_fn(input_sequence: np.ndarray) -> np.ndarray:
        """
        Args:
            input_sequence (np.ndarray): 输入序列,形状为 (batch_size, seq_len)。
        Returns:
            np.ndarray: 每个 token 的概率分布,形状为 (batch_size, seq_len, vocab_size)。
        """

        # 如果输入的序列为空,返回模型平均概率分布---压缩和解压我必须要看见这个
        if input_sequence.size == 0:
            initial_probs = np.ones((1, 1, 256), dtype=np.float32) / 256  # 均匀分布
            return initial_probs

        # turn to torch tensor
        input_tensor = torch.tensor(input_sequence, dtype=torch.long).cuda()
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)
        with torch.no_grad():
            # get logits
            logits = model(input_tensor)

        logits = logits[..., :256]
        logits = logits.float()
        assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
        probs = torch.softmax(logits, dim=-1)
        probs = probs.float().cpu().numpy()

        return probs
    
    return predict_fn


def m1_arithmetic_compress(
    data: bytes,
    predict_fn: Callable,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | tuple[bytes, int]:
    """Compress `data` with an arithmetic coder driven by a language model.

    Args:
        data: raw bytes; each byte is one symbol in [0, ALPHABET_SIZE).
        predict_fn: maps a (1, k) token array to (1, k, ALPHABET_SIZE)
            next-token probabilities.
        return_num_padded_bits: if True, also return how many zero bits were
            appended to round the bitstream up to whole bytes.
        use_slow_lossless_compression: if True, re-run the model on every
            growing prefix (O(n^2) model calls, but mirrors decompression
            exactly); if False, run the model once over the full sequence.

    Returns:
        Compressed bytes, or (compressed bytes, num_padded_bits).
    """
    sequence_array = np.frombuffer(data, dtype=np.uint8)

    if use_slow_lossless_compression:
        # Position k is predicted from the k-token prefix; k == 0 uses the
        # uniform distribution, matching the decompressor's first step.
        step_probs = []
        for k in range(len(sequence_array)):
            prefix = sequence_array[:k]
            if prefix.size == 0:
                step_probs.append(
                    np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
                )
            else:
                model_probs = predict_fn(prefix[None])  # (1, k, ALPHABET_SIZE)
                step_probs.append(model_probs[0, -1])
        # reshape keeps a well-formed (0, ALPHABET_SIZE) array for empty input.
        probs = np.asarray(step_probs, dtype=np.float32).reshape(-1, ALPHABET_SIZE)
    else:
        if sequence_array.size == 0:
            # Fix: keep a consistent 2-D shape for empty input — the previous
            # 1-D uniform array broke the shape asserts below.
            probs = np.empty((0, ALPHABET_SIZE), dtype=np.float32)
        else:
            full_probs = predict_fn(sequence_array[None])[0, ...]
            # Shift right by one: position k is predicted from positions < k,
            # and the very first position uses the uniform distribution.
            probs = np.concatenate(
                [
                    np.ones(ALPHABET_SIZE, dtype=np.float32)[None] / ALPHABET_SIZE,
                    full_probs[:-1, ...],
                ], axis=0
            )
    probs /= probs.sum(axis=-1, keepdims=True)

    assert probs.shape[1] == ALPHABET_SIZE, "The shape of probs is not correct."
    assert probs.shape[0] == len(sequence_array), "The shape of probs is not correct."
    if len(probs):
        assert np.isclose(sum(probs[0]), 1, atol=1e-6), "The probs is not normalized."

    output = []
    encoder = arithmetic_coder.Encoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        output_fn=output.append
    )
    for pdf, symbol in zip(probs, sequence_array):
        encoder.encode(utils.normalize_pdf_for_arithmetic_coding(pdf), symbol)
    encoder.terminate()

    compressed_bits = ''.join(map(str, output))
    # Pad the bitstream with zeros so it fits a whole number of bytes.
    compressed_bytes, num_padded_bits = utils.bits_to_bytes(compressed_bits)

    if return_num_padded_bits:
        return compressed_bytes, num_padded_bits
    return compressed_bytes
    
def m1_arithmetic_decompress(
    compressed: bytes, 
    predict_fn: Callable,
    num_padded_bits: int, 
    length: int,
) -> bytes:
    """Decode `length` symbols from an arithmetically coded bitstream.

    Args:
        compressed: bytes produced by m1_arithmetic_compress.
        predict_fn: the same model-prediction function used during compression
            (the model must be deterministic for lossless round-tripping).
        num_padded_bits: number of zero bits appended during compression.
        length: number of original symbols (bytes) to decode.

    Returns:
        The decompressed data as bytes.
    """
    bits = utils.bytes_to_bits(compressed, num_padded_bits=num_padded_bits)
    data_iter = iter(bits)
    # Feed the decoder one bit at a time; None signals end of stream.
    def _input_fn() -> int | None:
        try:
            return int(next(data_iter))
        except StopIteration:
            return None
    
    decoder = arithmetic_coder.Decoder(
        base=ARITHMETIC_CODER_BASE,
        precision=ARITHMETIC_CODER_PRECISION,
        input_fn=_input_fn
    )
    sequence_array = np.empty((0,), dtype=np.uint8)
    for k in range(length):
        # Predict token k+1 from the k tokens decoded so far.
        if k == 0:
            # First token: empty context, uniform initial distribution —
            # this must mirror the compressor's first step exactly.
            input_seq = np.empty((0,), dtype=np.uint8)
            current_probs = np.ones(ALPHABET_SIZE, dtype=np.float32) / ALPHABET_SIZE
        else:
            input_seq = sequence_array  # the k already-decoded tokens
            model_probs = predict_fn(input_seq[None])  # shape (1, k, 256)
            current_probs = model_probs[0, -1]  # last position = next-token dist
        # Renormalize in case the model output is not exactly normalized.
        current_probs /= current_probs.sum()
        token = decoder.decode(utils.normalize_pdf_for_arithmetic_coding(current_probs))
        # Concatenate with an explicit uint8 dtype: np.append widened the
        # dtype, which corrupted the byte stream with leading zeros.
        sequence_array = np.concatenate([sequence_array, np.array([token], dtype=np.uint8)])
    return sequence_array.tobytes()

def m1_arithmetic_compress_with_windows(
    data: bytes,
    predict_fn: Callable,
    window_bit_size=WINDOW_SIZE,
    return_num_padded_bits: bool = True,
    use_slow_lossless_compression: bool = True
) -> bytes | Tuple[bytes, list[int], list[int]]:
    """Compress `data` into windows whose compressed size fits a bit budget.

    Bytes are appended to the current window one at a time; after each append
    the window is recompressed, and when the compressed bitstream exceeds
    `window_bit_size` bits the last byte is removed, the window is finalized,
    and a new window is started with that byte. Note this recompresses the
    window on every appended byte, so cost grows quadratically per window.

    Args:
        data: raw bytes to compress.
        predict_fn: model-prediction function passed to m1_arithmetic_compress.
        window_bit_size: maximum compressed size of one window, in bits.
        return_num_padded_bits: if True, also return the per-window padding
            bit counts and per-window original lengths needed to decompress.
        use_slow_lossless_compression: forwarded to m1_arithmetic_compress.

    Returns:
        Concatenated compressed bytes, or a tuple of
        (compressed bytes, num_padded_bits_per_window, original_lengths_per_window).
    """
    compressed_windows = []
    num_padded_bits_per_window = []
    original_lengths_per_window = []
    current_window = []

    for byte in data:
        # Tentatively add this byte to the current window.
        current_window.append(byte)
        # Compress the window to see whether it still fits the bit budget.
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window), 
            predict_fn=predict_fn,
            return_num_padded_bits=True, 
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        # Measure the compressed bitstream length.
        compressed_bits = ''.join(map(str, utils.bytes_to_bits(compressed, num_padded)))
        # If the budget is exceeded, finalize the window without this byte.
        if len(compressed_bits) > window_bit_size:
            current_window.pop()
            # Recompress the trimmed window and record its metadata.
            compressed, num_padded = m1_arithmetic_compress(
                bytes(current_window), 
                predict_fn=predict_fn,
                return_num_padded_bits=True, 
                use_slow_lossless_compression=use_slow_lossless_compression,
            )
            compressed_bits = ''.join(map(str, utils.bytes_to_bits(compressed, num_padded)))
            compressed_windows.append(compressed)
            num_padded_bits_per_window.append(num_padded)
            original_lengths_per_window.append(len(current_window))
            # Start the next window with the byte that did not fit.
            current_window = [byte]

    # Finalize the trailing (possibly underfull) window.
    if current_window:
        compressed, num_padded = m1_arithmetic_compress(
            bytes(current_window), 
            predict_fn=predict_fn,
            return_num_padded_bits=True, 
            use_slow_lossless_compression=use_slow_lossless_compression,
        )
        compressed_windows.append(compressed)
        num_padded_bits_per_window.append(num_padded)
        original_lengths_per_window.append(len(current_window))

    all_compressed_bytes = b''.join(compressed_windows)

    if return_num_padded_bits:
        return all_compressed_bytes, num_padded_bits_per_window, original_lengths_per_window
    return all_compressed_bytes


def m1_arithmetic_decompress_with_windows(
    compressed: bytes,
    predict_fn: Callable,
    window_bit_size,
    num_padded_bits_per_window: list[int],
    original_lengths_per_window: list[int]
) -> bytes:
    """Decompress data produced by m1_arithmetic_compress_with_windows.

    Args:
        compressed: concatenated per-window compressed bytes.
        predict_fn: the same model-prediction function used for compression.
        window_bit_size: bit budget each window was compressed under.
        num_padded_bits_per_window: padding bit count recorded per window.
        original_lengths_per_window: original byte length recorded per window.

    Returns:
        The concatenated decompressed bytes of all windows.
    """
    decoded_bytes = b''
    start = 0
    bitstream = utils.bytes_to_bits(compressed)
    for num_padded, length in zip(num_padded_bits_per_window, original_lengths_per_window):
        # NOTE(review): this slicing assumes each window occupies exactly
        # window_bit_size bits in the concatenated stream — confirm the
        # compressor pads every window to that size.
        window_bitstream = bitstream[start:start + window_bit_size]
        window_compressed, _ = utils.bits_to_bytes(window_bitstream)
        # Fix: debug print() calls replaced with logger.debug so library
        # callers are not spammed on stdout.
        logger.debug("window compressed data: %s", window_compressed)
        decoded_window = m1_arithmetic_decompress(
            window_compressed, 
            predict_fn,
            num_padded, 
            length,
        )
        logger.debug("decoded window (%d bytes): %s", len(decoded_window), decoded_window)
        decoded_bytes += decoded_window
        # Advance to the next window's bits.
        start += window_bit_size
    return decoded_bytes


def test_equal_window_compression(sequence: str):
    """Round-trip `sequence` through windowed compression and report statistics.

    Loads the M1 checkpoint from a hard-coded path, compresses the tokenized
    sequence with the fast (non-slow) path, decompresses it, and prints the
    compression ratio. A failed round trip is reported rather than raised.
    """
    print(f"测试序列: {sequence}")
    model, tokenizer, predict_fn = load_m1_model_and_tokenizer(consolidated_path = "/mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_1M_steps10k_bs32_seqlen2048/checkpoints/0000010000")
    original_bytes = tokenizer.encode(sequence)
    print(f"token 字节: {original_bytes}")
    original_bytes = bytes(original_bytes)
    print(f"转为字节流: {original_bytes}")
    compressed_data = m1_arithmetic_compress_with_windows(
        original_bytes,
        predict_fn=predict_fn,
        window_bit_size=WINDOW_SIZE,
        return_num_padded_bits=True,
        use_slow_lossless_compression=False
    )
    compressed_bytes, num_padded_bits_per_window, original_lengths_per_window = compressed_data
    print(f"压缩后的字节: {compressed_bytes}")
    print(f"窗口填充位数数组: {num_padded_bits_per_window}")
    print(f"窗口原始长度数组: {original_lengths_per_window}")
    try:
        decoded_bytes = m1_arithmetic_decompress_with_windows(
            compressed_bytes, 
            predict_fn,
            WINDOW_SIZE,
            num_padded_bits_per_window, 
            original_lengths_per_window
        )
        decoded_sequence = decoded_bytes.decode('utf-8')

        assert decoded_sequence == sequence, f"解码失败:原始={sequence}, 解码={decoded_sequence}"
    except Exception as exc:
        # Fix: the bare `except:` swallowed every error — including the
        # round-trip assert above — without any detail. Report what failed.
        print(f"解码失败: {exc}")

    compression_ratio = len(original_bytes) / len(compressed_bytes)
    print(f"原始大小: {len(original_bytes)} bytes")
    print(f"压缩后大小: {len(compressed_bytes)} bytes")
    print(f"各窗口填充位数: {num_padded_bits_per_window}")
    print(f"各窗口原始长度: {original_lengths_per_window}")
    print(f"压缩率: {compression_ratio:.2f}x")


if __name__ == "__main__":
    test_sequences = [
        "if month % 2 : return 30",
        "def __init__(self, name): self.name = name",
        "import pandas as pd\n  import matplotlib.pyplot as plt",
        "import torch",
        # Fix: a missing comma after this entry caused implicit string
        # concatenation with the next literal, silently merging two intended
        # test cases into one.
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as nn",
        "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F" * 100,
    ]
    for seq in test_sequences:
        print(f"=== m1语言模型 Equal_Window 算术压缩 ===")
        start_time = time.time()
        test_equal_window_compression(seq)
        end_time = time.time()
        print(f"用时: {end_time - start_time} 秒")