| |
| |
|
|
| import json |
| import base64 |
| import argparse |
| import os |
| import sys |
| import gzip |
| import math |
| import torch |
| import torch.multiprocessing as mp |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import pandas as pd |
| import Levenshtein |
| from typing import List, Callable, Tuple, Optional |
| from concurrent.futures import ThreadPoolExecutor |
|
|
| |
| |
| |
# Avoid HuggingFace tokenizer fork/thread parallelism warnings and deadlocks
# in the multi-process workers spawned below.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Make both the current working directory and the script's own directory
# importable so local packages (e.g. m1_compression) resolve regardless of
# where the script is launched from.
current_dir = os.getcwd()
if current_dir not in sys.path:
    sys.path.append(current_dir)
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.append(script_dir)
|
|
| |
| |
| |
| try: |
| from transformers import AutoTokenizer |
| except ImportError: |
| print("❌ Error: transformers not installed.") |
| sys.exit(1) |
|
|
| try: |
| from m1_compression import utils |
| from m1_compression.compressor import ( |
| load_m1_model_and_tokenizer, |
| ALPHABET_SIZE, |
| ARITHMETIC_CODER_BASE, |
| ARITHMETIC_CODER_PRECISION, |
| ) |
| from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder |
| from m1_compression.batched_arithmetic_coder import _pdf_to_cdf |
| except ImportError as e: |
| print(f"❌ Error: m1_compression not found. {e}") |
| sys.exit(1) |
|
|
| |
| |
| |
|
|
def vread(buf: bytes, i: int):
    """Decode one LEB128-style varint from *buf* starting at index *i*.

    Each byte contributes its low 7 bits, least-significant group first;
    a byte with the high bit clear terminates the number.

    Returns:
        (value, next_index): the decoded integer and the index just past it.
    """
    value = 0
    shift = 0
    while True:
        byte = buf[i]
        i += 1
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return value, i
        shift += 7
|
|
|
|
def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[bytes]:
    """Return only the byte segments (windows) that need compression,
    skipping the uncompressed gap regions between them.

    *b64_stream* is a base64-encoded sequence of (gap, size) varint pairs:
    each gap advances the cursor past an uncompressed region, each size
    selects the window that follows. Returns [] on any decoding problem
    (best-effort by design).
    """
    if not b64_stream:
        return []
    try:
        packed = base64.b64decode(b64_stream)
        pos = 0
        offset = 0
        windows: List[bytes] = []
        total = len(packed)
        limit = len(input_bytes)
        while pos < total:
            gap, pos = vread(packed, pos)
            size, pos = vread(packed, pos)
            lo = offset + gap
            hi = lo + size
            # A window running past the input means the stream is stale/corrupt.
            if hi > limit:
                break
            windows.append(input_bytes[lo:hi])
            offset = hi
        return windows
    except Exception:
        return []
|
|
|
|
def token_ids_to_str(ids: List[int]) -> str:
    """Map token ids to a string, one character per id, so string
    edit-distance routines can run over token sequences.

    Ids above the Unicode maximum (0x10FFFF) are clamped to it.
    """
    return "".join(chr(min(token_id, 0x10FFFF)) for token_id in ids)
|
|
|
|
def bytes_to_latin1_str(b: bytes) -> str:
    """Losslessly view raw bytes as a str (one char per byte, U+0000..U+00FF)
    so binary streams can be fed to string edit-distance routines."""
    return str(b, "latin1")
|
|
|
|
def pad_batch_fast(batch: List[bytes]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert List[bytes] -> (padded_batch[int64], lengths[int64]).

    Every sequence is zero-padded to the batch maximum length.
    Key optimization: numpy.frombuffer plus one bulk row copy, avoiding
    the per-byte Python-level list(data) conversion.
    """
    if not batch:
        return torch.empty((0, 0), dtype=torch.long), torch.empty((0,), dtype=torch.long)

    n = len(batch)
    lens = np.fromiter(map(len, batch), dtype=np.int32, count=n)
    width = int(lens.max())
    padded_np = np.zeros((n, width), dtype=np.uint8)
    for row, raw in enumerate(batch):
        view = np.frombuffer(raw, dtype=np.uint8)  # zero-copy view of the bytes
        if view.size:
            padded_np[row, : view.size] = view
    padded = torch.from_numpy(padded_np).to(torch.long)
    lengths = torch.from_numpy(lens.astype(np.int64))
    return padded, lengths
|
|
|
|
def iter_jsonl_shard_bytes(file_path: str, shard_rank: int, shard_world: int):
    """Yield the raw lines of one byte-range shard of a jsonl file.

    The file is split into ``shard_world`` byte ranges; shard ``shard_rank``
    yields exactly the lines that *start* inside its half-open range, so every
    line of the file is produced by exactly one shard. Suited for saturating
    multiple GPUs with a single large file.

    Args:
        file_path: path to the jsonl file.
        shard_rank: 0-based shard index.
        shard_world: total number of shards.

    Yields:
        bytes: one raw line (including its trailing newline) at a time.
    """
    file_size = os.path.getsize(file_path)
    start = (file_size * shard_rank) // shard_world
    end = (file_size * (shard_rank + 1)) // shard_world

    with open(file_path, "rb") as f:
        if start > 0:
            # Only skip a *partial* line. Peek at the byte just before
            # `start`: if it is a newline, `start` sits exactly on a line
            # boundary and that line belongs to this shard (the previous
            # shard stops once tell() >= its end, so it never reads it).
            # Unconditionally calling readline() here — as the original code
            # did — silently dropped such boundary lines.
            f.seek(start - 1)
            if f.read(1) != b"\n":
                f.readline()
        # shard 0 starts at offset 0 and is already positioned there.

        while f.tell() < end:
            line = f.readline()
            if not line:
                break
            yield line
|
|
|
|
| |
| |
| |
|
|
def batched_m1_compress_predict_fn(model):
    """Wrap *model* into a ``predict_fn(input_tensor, **kwargs) -> probs``.

    The returned function promotes a 1-D input to a batch of one, runs the
    model under inference mode, keeps only the first 256 logits (the byte
    alphabet), and returns float32 softmax probabilities.
    """
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        batch = input_tensor.unsqueeze(0) if input_tensor.dim() == 1 else input_tensor
        with torch.inference_mode():
            raw_logits = model(batch, **kwargs)
            byte_logits = raw_logits[..., :256].float()
            probs = torch.softmax(byte_logits, dim=-1)
        return probs
    return predict_fn
|
|
|
|
def compress_segments_smart_batch_bytes(
    all_segments: List[bytes],
    batched_predict_fn: Callable,
    first_byte_prob: torch.Tensor,
    device: torch.device,
    encoder: CPUArithmeticEncoder,
    gpu_batch_size: int = 256,
    bit_threshold: int = 64,
) -> List[bytes]:
    """High-performance arithmetic-coding compression:
    1) sort segments by length first, to reduce padding waste per batch
    2) run model inference on the GPU, arithmetic encoding on the CPU
    3) return each segment's compressed bytes directly (no List[int] round-trip)
    """
    M = len(all_segments)
    if M == 0:
        return []

    # Stable sort by length so each GPU mini-batch holds similarly sized
    # segments; stable keeps equal-length segments in input order.
    lengths = np.fromiter((len(s) for s in all_segments), dtype=np.int32, count=M)
    sorted_indices = np.argsort(lengths, kind="stable")
    sorted_segments = [all_segments[i] for i in sorted_indices]

    # Results are scattered back to each segment's original position.
    out: List[Optional[bytes]] = [None] * M

    for i in range(0, M, gpu_batch_size):
        batch_slice = sorted_segments[i : i + gpu_batch_size]
        batch_orig_indices = sorted_indices[i : i + gpu_batch_size]

        try:
            padded_batch_cpu, lengths_cpu = pad_batch_fast(batch_slice)

            # Pinned host memory enables the async (non_blocking) H2D copy.
            padded_batch = padded_batch_cpu.pin_memory().to(device, non_blocking=True)

            # presumably (batch, seq, 256) next-byte probabilities — TODO confirm
            prompt_probs = batched_predict_fn(padded_batch)

            # Shift by one: byte t is coded under the model's prediction made
            # at position t-1; the very first byte uses the static prior.
            final_probs = torch.cat(
                [
                    first_byte_prob.expand(prompt_probs.shape[0], -1, -1),
                    prompt_probs[:, :-1, ...],
                ],
                dim=1,
            )

            # Renormalize so the pdf is valid for integer arithmetic coding,
            # then build the cumulative distribution.
            final_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs)
            cdfs_gpu = _pdf_to_cdf(final_probs)

            # Gather the [cdf(byte), cdf(byte+1)) interval bounds for every
            # actual input byte in one batched gather.
            cdf_low = cdfs_gpu.gather(2, padded_batch.unsqueeze(-1)).squeeze(-1)
            cdf_high = cdfs_gpu.gather(2, (padded_batch + 1).unsqueeze(-1)).squeeze(-1)
            cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1)

            # CPU-side batched arithmetic encoding; padded positions are
            # excluded via lengths_cpu.
            enc_out = encoder.incremental_batched_encode(
                cdf_ends.cpu(),
                ALPHABET_SIZE,
                lengths_cpu,
                bit_threshold=bit_threshold,
                force_padding_to_threshold=False,
                return_num_padded_bits=False,
            )

            # The encoder may return (codes, extra) or just codes.
            if isinstance(enc_out, tuple):
                chunked_compressed_bytes = enc_out[0]
            else:
                chunked_compressed_bytes = enc_out

            for idx, code in zip(batch_orig_indices, chunked_compressed_bytes):
                out[int(idx)] = bytes(code)

        except Exception:
            # Best-effort fallback: on any failure, emit the raw segment so
            # the batch still yields one output per input.
            for idx, seg in zip(batch_orig_indices, batch_slice):
                out[int(idx)] = seg

    return [x if x is not None else b"" for x in out]
|
|
|
|
class M1ACManager:
    """Owns one M1 byte-model plus an arithmetic coder pinned to a single
    CUDA device, and exposes batched compression of (text, windows) samples."""

    def __init__(self, model_path: str, first_prob_path: str, device_id: int,
                 gpu_batch_size: int = 256, bit_threshold: int = 64):
        self.device = torch.device(f"cuda:{device_id}")
        self.gpu_batch_size = gpu_batch_size
        self.bit_threshold = bit_threshold

        # Load the model once per process and move it to this GPU in eval mode.
        self.model, _, _ = load_m1_model_and_tokenizer(model_path)
        self.model.to(self.device)
        self.model.eval()
        self.predict_fn = batched_m1_compress_predict_fn(self.model)

        # Static prior for the first byte of every segment; fall back to a
        # uniform distribution when no unigram-probability file is provided.
        if first_prob_path and os.path.exists(first_prob_path):
            with open(first_prob_path, "r") as f:
                prob_data = json.load(f)
            self.first_byte_prob = torch.tensor(prob_data, dtype=torch.float32, device=self.device)
            # Reshape a flat prior to (1, 1, alphabet) for broadcasting.
            if self.first_byte_prob.dim() == 1:
                self.first_byte_prob = self.first_byte_prob.unsqueeze(0).unsqueeze(0)
        else:
            self.first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device=self.device) / ALPHABET_SIZE

        # CPU-side arithmetic encoder shared across all batches.
        self.encoder = CPUArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION)

    def compress_batch_smart_bytes(self, inputs: List[Tuple[str, Optional[str]]]) -> List[bytes]:
        """
        inputs: List[(text, windows_b64_or_None)]
        Return: List[bytes] — for each sample, the concatenated AC bitstream
        (bytes) of all of its segments.
        """
        all_segments_flat: List[bytes] = []
        # Per-sample (start, end) index range into the flat segment list.
        sample_map: List[Tuple[int, int]] = []
        current_idx = 0

        for text, windows_b64 in inputs:
            raw_bytes = text.encode("utf-8")
            sample_segs: List[bytes] = []

            # Prefer precomputed windows; otherwise fall back to fixed
            # 512-byte chunking of the whole text.
            if windows_b64:
                sample_segs = unpack_windows(raw_bytes, windows_b64)
            else:
                CHUNK = 512
                for j in range(0, len(raw_bytes), CHUNK):
                    sample_segs.append(raw_bytes[j : j + CHUNK])

            count = len(sample_segs)
            sample_map.append((current_idx, current_idx + count))
            all_segments_flat.extend(sample_segs)
            current_idx += count

        if not all_segments_flat:
            return [b"" for _ in inputs]

        # One flat call lets length-sorted batching mix segments across samples.
        compressed_chunks_flat = compress_segments_smart_batch_bytes(
            all_segments_flat,
            self.predict_fn,
            self.first_byte_prob,
            self.device,
            self.encoder,
            gpu_batch_size=self.gpu_batch_size,
            bit_threshold=self.bit_threshold,
        )

        # Re-assemble each sample's segments in their original order.
        results: List[bytes] = []
        for start, end in sample_map:
            results.append(b"".join(compressed_chunks_flat[start:end]))
        return results
|
|
|
|
| |
| |
| |
|
|
def run_gzip_task(text_pair: Tuple[str, str]) -> float:
    """Gzip both texts and return the edit distance between the compressed
    streams, normalized by the length of the first stream (0.0 if empty)."""
    original, perturbed = text_pair
    compressed_a = gzip.compress(original.encode("utf-8"))
    compressed_b = gzip.compress(perturbed.encode("utf-8"))
    if not compressed_a:
        return 0.0
    dist = Levenshtein.distance(
        bytes_to_latin1_str(compressed_a),
        bytes_to_latin1_str(compressed_b),
    )
    return dist / len(compressed_a)
|
|
|
|
def process_one_file(
    gpu_id: int,
    file_path: str,
    tokenizer: AutoTokenizer,
    ac_manager: M1ACManager,
    max_lines: int,
    worker_batch_size: int,
    gzip_threads: int,
    shard_rank: int,
    shard_world: int,
) -> dict:
    """
    Process one jsonl file (or one byte-range shard of it) and return a
    results dict of normalized edit distances per algorithm.
    """
    results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}

    # When one file is sharded across GPUs, split the line budget evenly too.
    if shard_world > 1 and max_lines > 0:
        shard_max_lines = int(math.ceil(max_lines / shard_world))
    else:
        shard_max_lines = max_lines

    # Buffers for the current batch; flushed every worker_batch_size lines.
    raw_texts: List[str] = []
    pert_texts: List[str] = []
    metas: List[Optional[str]] = []  # per-sample windows b64 stream (or None)

    processed_total = 0

    # gzip + Levenshtein run on CPU threads, overlapping with GPU work.
    thread_pool = ThreadPoolExecutor(max_workers=gzip_threads)

    def flush():
        # Score the buffered batch with all three methods, then clear buffers.
        nonlocal raw_texts, pert_texts, metas
        if not raw_texts:
            return

        curr_batch_size = len(raw_texts)

        # 1) gzip baseline, fanned out over the thread pool.
        gz_vals = list(thread_pool.map(run_gzip_task, zip(raw_texts, pert_texts)))
        results["Gzip"].extend(gz_vals)

        # 2) tokenizer baseline: edit distance over token-id "strings".
        try:
            tok1 = tokenizer(raw_texts, add_special_tokens=False)["input_ids"]
            tok2 = tokenizer(pert_texts, add_special_tokens=False)["input_ids"]
            for a, b in zip(tok1, tok2):
                if a:
                    d = Levenshtein.distance(token_ids_to_str(a), token_ids_to_str(b))
                    results["Tokenizer"].append(d / len(a))
        except Exception:
            # Best-effort: drop the tokenizer metric for this batch on failure.
            pass

        # 3) M1 arithmetic coding: originals use their precomputed windows,
        # perturbed texts fall back to fixed chunking (meta=None); both are
        # compressed in one batched call.
        orig_inputs = list(zip(raw_texts, metas))
        pert_inputs = list(zip(pert_texts, [None] * curr_batch_size))
        both_inputs = orig_inputs + pert_inputs

        try:
            both_streams = ac_manager.compress_batch_smart_bytes(both_inputs)
            ac1_list = both_streams[:curr_batch_size]
            ac2_list = both_streams[curr_batch_size:]

            for a1, a2 in zip(ac1_list, ac2_list):
                if a1:
                    d = Levenshtein.distance(bytes_to_latin1_str(a1), bytes_to_latin1_str(a2))
                    results["AC_M1"].append(d / len(a1))
        except Exception as e:
            print(f"[GPU {gpu_id}] AC Batch Error: {e}")

        raw_texts, pert_texts, metas = [], [], []

    line_iter = iter_jsonl_shard_bytes(file_path, shard_rank, shard_world)

    for i, line in enumerate(line_iter):
        if shard_max_lines > 0 and i >= shard_max_lines:
            break
        try:
            # Cheap pre-filters before paying for JSON parsing.
            if not line or len(line) < 100:
                continue
            data = json.loads(line)
            text = data.get("text", "")
            if not isinstance(text, str) or len(text) < 50:
                continue

            windows = data.get("windows_starts_lens_b64")

            # Perturbation: drop the first 20% of the text (at least 1 char).
            cut_idx = max(1, int(len(text) * 0.2))
            text_p = text[cut_idx:]

            raw_texts.append(text)
            pert_texts.append(text_p)
            metas.append(windows)

            processed_total += 1
            if len(raw_texts) >= worker_batch_size:
                flush()

            if processed_total % 2000 == 0:
                print(f"[GPU {gpu_id}] processed {processed_total} lines (file={os.path.basename(file_path)}, shard={shard_rank}/{shard_world})")

        except Exception:
            # Malformed line: skip it and keep going.
            continue

    # Score whatever remains in the buffers, then release the pool.
    flush()
    thread_pool.shutdown(wait=True)

    print(f"[GPU {gpu_id}] done file={os.path.basename(file_path)} shard={shard_rank}/{shard_world} total={processed_total}")
    return results
|
|
|
|
def process_files_worker(
    rank: int,
    gpu_id: int,
    file_paths: List[str],
    output_dir: str,
    model_path: str,
    prob_path: str,
    max_lines: int,
    worker_batch_size: int,
    gzip_threads: int,
    shard_mode: bool,
    gpu_batch_size: int,
    bit_threshold: int,
):
    """
    One GPU process: load the tokenizer + M1 model once, then sequentially
    process the files assigned to it (or its shard of a single file),
    writing one result JSON per file into output_dir.
    """
    try:
        torch.cuda.set_device(gpu_id)

        # Tokenizer for the token-level baseline; fall back to gpt2 if the
        # primary tokenizer cannot be loaded.
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                "infly/OpenCoder-1.5B-Base",
                trust_remote_code=True,
                use_fast=True,
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
        except Exception:
            tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)

        # Model + arithmetic coder live for the whole worker lifetime.
        ac_manager = M1ACManager(
            model_path=model_path,
            first_prob_path=prob_path,
            device_id=gpu_id,
            gpu_batch_size=gpu_batch_size,
            bit_threshold=bit_threshold,
        )

        for fp in file_paths:
            base = os.path.basename(fp)

            # In shard mode every worker reads the same file but only its own
            # byte range; otherwise each worker owns whole files.
            if shard_mode:
                shard_rank = rank
                shard_world = torch.cuda.device_count()
            else:
                shard_rank = 0
                shard_world = 1

            print(f"[GPU {gpu_id}] start file={base} shard={shard_rank}/{shard_world}")

            res = process_one_file(
                gpu_id=gpu_id,
                file_path=fp,
                tokenizer=tokenizer,
                ac_manager=ac_manager,
                max_lines=max_lines,
                worker_batch_size=worker_batch_size,
                gzip_threads=gzip_threads,
                shard_rank=shard_rank,
                shard_world=shard_world,
            )

            # One result file per (gpu, rank, shard, input-file) so the main
            # process can merge them afterwards.
            out_name = f"res_gpu{gpu_id}_rank{rank}_shard{shard_rank}of{shard_world}_{base}.json"
            out_path = os.path.join(output_dir, out_name)
            with open(out_path, "w") as f:
                json.dump(res, f)

    except Exception as e:
        # Worker is a separate process: log the full traceback, don't crash
        # the parent silently.
        print(f"❌ [GPU {gpu_id}] Worker Error: {e}")
        import traceback
        traceback.print_exc()
|
|
|
|
| |
| |
| |
|
|
def main():
    """CLI entry: assign jsonl files across all visible GPUs, run one worker
    process per GPU, then merge the per-worker result JSONs into summary
    stats and a KDE plot."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", type=str, required=True)
    parser.add_argument("--m1_model", type=str, required=True)
    parser.add_argument("--first_prob_path", type=str, required=True)
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_fast_opt")
    parser.add_argument("--max_lines", type=int, default=10000)

    # Tuning knobs for throughput.
    parser.add_argument("--max_files", type=int, default=8, help="只取前 N 个 jsonl 文件;0 表示不限制")
    parser.add_argument("--worker_batch_size", type=int, default=500, help="flush 的行数 batch")
    parser.add_argument("--gzip_threads", type=int, default=8, help="每个 GPU 进程内用于 gzip 的线程数")
    parser.add_argument("--ac_gpu_batch_size", type=int, default=256, help="AC 推理的 GPU mini-batch size")
    parser.add_argument("--ac_bit_threshold", type=int, default=64, help="Arithmetic coder bit_threshold(16->64/128 往往更快)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Input jsonl files, excluding writer-side temp files.
    files = [
        os.path.join(args.input_dir, f)
        for f in os.listdir(args.input_dir)
        if f.endswith(".jsonl") and "writer" not in f
    ]
    files.sort()

    if args.max_files and args.max_files > 0:
        files = files[: args.max_files]

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("❌ No GPU detected.")
        return
    if not files:
        print("❌ No jsonl files found.")
        return

    # With a single input file and multiple GPUs, shard that file by byte
    # ranges so every GPU still gets work.
    shard_mode = (len(files) == 1 and num_gpus > 1)

    assignments: List[List[str]] = [[] for _ in range(num_gpus)]
    if shard_mode:
        for r in range(num_gpus):
            assignments[r] = [files[0]]
        print(f"🚀 Single-file shard mode enabled: {files[0]} -> {num_gpus} shards")
    else:
        # Round-robin whole files across GPUs.
        for idx, fp in enumerate(files):
            assignments[idx % num_gpus].append(fp)
        non_empty = sum(1 for a in assignments if a)
        print(f"🚀 Multi-file mode: {len(files)} files assigned across {non_empty}/{num_gpus} GPU workers")

    print(f"   worker_batch_size={args.worker_batch_size}, gzip_threads={args.gzip_threads}, ac_gpu_batch_size={args.ac_gpu_batch_size}, ac_bit_threshold={args.ac_bit_threshold}")

    # spawn is required for CUDA use in child processes.
    mp.set_start_method("spawn", force=True)
    procs = []
    for rank in range(num_gpus):
        if not assignments[rank]:
            continue
        gpu_id = rank % num_gpus
        p = mp.Process(
            target=process_files_worker,
            args=(
                rank,
                gpu_id,
                assignments[rank],
                args.output_dir,
                args.m1_model,
                args.first_prob_path,
                args.max_lines,
                args.worker_batch_size,
                args.gzip_threads,
                shard_mode,
                args.ac_gpu_batch_size,
                args.ac_bit_threshold,
            ),
        )
        p.start()
        procs.append(p)

    for p in procs:
        p.join()

    # Merge every worker's result JSON found in the output directory.
    # NOTE(review): this also picks up res_*.json left over from previous
    # runs in the same output_dir — confirm that is intended.
    print("✅ Merging results...")
    final_results = {"Gzip": [], "Tokenizer": [], "AC_M1": []}
    for fn in os.listdir(args.output_dir):
        if fn.startswith("res_") and fn.endswith(".json"):
            try:
                with open(os.path.join(args.output_dir, fn), "r") as f:
                    d = json.load(f)
                for k in final_results:
                    final_results[k].extend(d.get(k, []))
            except Exception:
                pass

    for k, v in final_results.items():
        print(f"   {k}: {len(v)} samples")

    # Summary statistics (mean + median) per algorithm.
    stats = {}
    for k, v in final_results.items():
        if v:
            stats[k] = {"count": int(len(v)), "mean": float(np.mean(v)), "p50": float(np.median(v))}
    with open(os.path.join(args.output_dir, "final_stats.json"), "w") as f:
        json.dump(stats, f, indent=2)
    print(f"📄 Saved stats -> {os.path.join(args.output_dir, 'final_stats.json')}")

    # KDE plot of the normalized edit distances; extreme outliers (>= 2.0)
    # are excluded to keep the plot readable.
    plot_data = []
    for algo, vals in final_results.items():
        for val in vals:
            if val < 2.0:
                plot_data.append({"Algorithm": algo, "NED": val})

    if plot_data:
        df = pd.DataFrame(plot_data)
        plt.figure(figsize=(10, 6))
        sns.kdeplot(data=df, x="NED", hue="Algorithm", fill=True, common_norm=False)
        plt.title("Compression Stability (Optimized)")
        plt.xlabel("Normalized Levenshtein Distance")
        plt.xlim(0, 1.2)
        out_img = os.path.join(args.output_dir, "stability_fast_opt.png")
        plt.savefig(out_img, dpi=200)
        print(f"📊 Saved plot -> {out_img}")

    print("🎉 Done.")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|
|
|
| """ |
| python fast_compare.py \ |
| --input_dir /mnt/hdfs/user/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac \ |
| --m1_model /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/m1_checkpoints/m1_40M_lr1e-3_steps200k_bs8_seqlen2048_python/checkpoints/0000200000 \ |
| --first_prob_path /mnt/bn/tiktok-mm-5/aiic/users/linzheng/artifacts/ac_unigram_probs/python500k_unigram_prob.json \ |
| --max_lines 10000 \ |
| -o analysis_output_fast_opt \ |
| --max_files 8 |
| """ |