import json
import base64
import argparse
import os
from collections import defaultdict
from itertools import combinations
import re
from tqdm import tqdm
from typing import List, Dict, Any, Tuple


try:
    import Levenshtein
except ImportError:
    print("❌ Error: 'python-Levenshtein' is not installed.")
    print("Run: pip install python-Levenshtein")
    exit(1)


import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

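# vread() decodes one unsigned LEB128-style varint from `buf` starting at
# index `i`: each byte contributes its low 7 bits, least-significant group
# first, and a high bit of 1 means more bytes follow. For example, the bytes
# 0x96 0x01 decode to 0x16 | (0x01 << 7) = 150. It returns (value, next_index).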
def vread(buf: bytes, i: int):
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:
            return val, i
        shift += 7

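# unpack_windows() interprets the base64 metadata as a sequence of (gap, size)
# varint pairs laid over the raw input bytes: `gap` uncovered bytes precede
# each compressed window of `size` bytes. It returns (chunk, indicator) tuples
# in original order, where indicator 0 marks literal text and indicator 1
# marks a compressed window.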
def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    try:
        buf, i, cursor, byte_windows = base64.b64decode(b64_stream), 0, 0, []
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            if gap > 0: byte_windows.append((input_bytes[cursor:start], 0))
            end = start + size
            byte_windows.append((input_bytes[start:end], 1))
            cursor = end
        if cursor < len(input_bytes): byte_windows.append((input_bytes[cursor:], 0))
        return byte_windows
    except (base64.binascii.Error, IndexError): return []

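# packed_bytes_to_pseudo() unpacks a little-endian bitstream into 9-bit
# "pseudo byte" values in the range 0..511. Worked example:
# bytes([0xFF, 0x01]) accumulates to 0x1FF after 16 bits, yielding [511].
# Values below 256 are treated later as literal raw bytes; values of 256 and
# above are presumably compressed token ids (cf. the compression_offset=256
# default below) -- that reading is inferred from the analysis loop, not
# documented here.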
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    out, acc, bits = [], 0, 0
    for byte in b:
        acc |= byte << bits
        bits += 8
        while bits >= 9:
            out.append(acc & 0x1FF)
            acc >>= 9
            bits -= 9
    return out

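# decompress_windows_starts_lens() decodes the same (gap, size) varint stream
# as unpack_windows(), but returns only the absolute start offsets and lengths
# of the compressed windows rather than slicing the raw text.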
def decompress_windows_starts_lens(b64_stream: str) -> tuple[list[int], list[int]]:
    try:
        buf = base64.b64decode(b64_stream)
        i = 0
        cursor = 0
        starts, lens = [], []
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            length = size
            starts.append(start)
            lens.append(length)
            cursor = start + length
        return starts, lens
    except (base64.binascii.Error, IndexError):
        return [], []

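# parse_parameters_from_path() reads run parameters out of the directory name.
# Each underscore-separated part is parsed either as "key-value" (split on the
# first '-') or as letters followed by digits such as "ow20". As a hypothetical
# example, a directory named "run_ow20_escapefb-true_iterative-true" would
# yield {'ow': '20', 'escapefb': 'true', 'iterative': 'true'} plus the
# hard-coded 'bits_per_compressed': 10.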
def parse_parameters_from_path(file_path: str) -> dict:
    """Parse run parameters from the directory (or file) name."""
    params = {}
    base_name = os.path.basename(os.path.normpath(file_path))
    parts = base_name.split('_')
    for part in parts:
        if '-' in part:
            key, value = part.split('-', 1)
            params[key.lower()] = value.lower()
        else:
            match = re.match(r'([a-zA-Z]+)(\d+)', part)
            if match:
                key, value = match.groups()
                params[key.lower()] = value
    params['bits_per_compressed'] = 10
    print(f"Parsed parameters from path '{base_name}': {params}")
    return params


def construct_compression_key(params: dict) -> str:
    """Construct the JSON field name that holds the compressed data."""
    ow = params.get('ow', 20)
    escape_fb = 'True' if params.get('escapefb', 'false') == 'true' else 'False'
    iterative = 'True' if params.get('iterative', 'true') == 'true' else 'False'
    force_padding = 'True' if params.get('forcepadding', 'false') == 'true' else 'False'
    key = f"m1_ac_ow{ow}_escapefb-{escape_fb}_iterative-{iterative}_forcepadding-{force_padding}"
    print(f"Compressed data key is: '{key}'")
    return key

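# analyze_token_collisions_in_directory() is the main driver. For every record
# it rebuilds the per-window pseudo-byte sequence, strips the literal tail
# (trailing pseudo values < 256, which must equal the corresponding raw bytes),
# and maps each remaining pure token sequence to the raw chunk it came from.
# A token sequence that maps to more than one distinct raw chunk is a
# collision; colliding variants are compared by Levenshtein distance and the
# results are written to a JSON report and a histogram plot.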
def analyze_token_collisions_in_directory(input_dir: str,
                                          output_dir: str,
                                          compression_offset: int = 256,
                                          max_files: int = -1,
                                          max_lines: int = -1):
    """Analyze token-level collisions across all .jsonl files in a directory."""
    if not os.path.isdir(input_dir):
        print(f"❌ Error: input directory '{input_dir}' is not valid"); return

    params = parse_parameters_from_path(input_dir)
    if 'ow' not in params:
        print(f"❌ Error: could not parse 'ow' from '{input_dir}'"); return
    compression_bit_threshold = params['ow']
    bits_per_compressed = params['bits_per_compressed']
    compression_key = construct_compression_key(params)
    print(f"Compression key is: '{compression_key}'")
    print(f"Params: compression_bit_threshold={compression_bit_threshold}, bits_per_compressed={bits_per_compressed}")

    jsonl_files = []
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.jsonl') and file.startswith('ocp'):
                jsonl_files.append(os.path.join(root, file))
    if not jsonl_files:
        print(f"❌ Error: no .jsonl files found under '{input_dir}'"); return
    print(f"🔍 Found {len(jsonl_files)} .jsonl files to process")

    if max_files > 0:
        jsonl_files = jsonl_files[:max_files]
        print(f"🔍 Small-batch mode: processing only {len(jsonl_files)} files, at most {max_lines} lines each")

    sequence_to_raw_map = defaultdict(list)
    total_lines = 0
    total_processed = 0  # incremented once per compressed window below
    total_mismatches = 0
    key_not_found_count = 0
    decode_errors = 0
    total_failed = 0
    print("🚀 Processing all files, building the global token -> raw_chunk_list map...")

    for file_path in tqdm(jsonl_files, desc="Processing files"):
        with open(file_path, 'r', errors='ignore') as f:
            line_nums = 0
            for line in f:
                line_nums += 1
                if max_lines > 0 and line_nums > max_lines:
                    print(f"📌 File {os.path.basename(file_path)} reached {max_lines} lines, stopping")
                    break
                total_lines += 1
                try:
                    data = json.loads(line)

                    # Report once (for debugging) if the constructed compression key is absent.
                    if compression_key not in data:
                        if key_not_found_count == 0:
                            print("\n\n--- Debug info: key mismatch ---")
                            print(f"Constructed key: '{compression_key}'")
                            print(f"Available keys in the JSON line: {list(data.keys())}")
                            print("---------------------------------")
                        key_not_found_count += 1
                        continue

                    required_keys = [compression_key, 'text', 'windows_starts_lens_b64', 'pseudo_lens_per_segment']
                    if not all(k in data and data[k] for k in required_keys):
                        continue

                    b64_decoded_bytes = base64.b64decode(data[compression_key])
                    mixed_pseudo_bytes = packed_bytes_to_pseudo(b64_decoded_bytes)

                    raw_text_bytes = data['text'].encode('utf-8')
                    all_segments = unpack_windows(raw_text_bytes, data['windows_starts_lens_b64'])
                    pseudo_lens = data["pseudo_lens_per_segment"]

                    if len(pseudo_lens) != len(all_segments):
                        raise ValueError("Metadata length mismatch between pseudo_lens and all_segments")

                    # Walk the segments in order, consuming pseudo values for each one.
                    pseudo_ptr = 0
                    for i in range(len(all_segments)):
                        raw_chunk, indicator = all_segments[i]
                        segment_pseudo_len = pseudo_lens[i]

                        if pseudo_ptr >= len(mixed_pseudo_bytes):
                            raise ValueError("Pseudo bytes stream exhausted prematurely.")
                        if indicator == 0:
                            # Literal (uncompressed) segment: just skip its pseudo values.
                            pseudo_ptr += segment_pseudo_len
                        elif indicator == 1:
                            chunk_len = len(raw_chunk)
                            if pseudo_ptr + segment_pseudo_len > len(mixed_pseudo_bytes):
                                raise ValueError("Pseudo bytes stream exhausted for a window.")

                            current_raw_bytes = list(raw_chunk)
                            current_pseudo_sequence = mixed_pseudo_bytes[pseudo_ptr : pseudo_ptr + segment_pseudo_len]

                            # Strip the literal tail: trailing pseudo values < 256 are raw
                            # escape bytes and must equal the corresponding raw bytes.
                            while current_pseudo_sequence and current_pseudo_sequence[-1] < 256:
                                last_pseudo = current_pseudo_sequence.pop()
                                last_raw = current_raw_bytes.pop()
                                assert last_pseudo == last_raw, "Mismatch in raw tail"

                            if current_pseudo_sequence:
                                pure_token_sequence = tuple(current_pseudo_sequence)
                                pure_raw_chunk = bytes(current_raw_bytes)
                                sequence_to_raw_map[pure_token_sequence].append(pure_raw_chunk)

                            pseudo_ptr += segment_pseudo_len
                            total_processed += 1

                except Exception:
                    total_failed += 1
                    continue

    print(f"✅ All files processed. {total_lines:,} lines read in total.")
    if key_not_found_count > 0: print(f"   {key_not_found_count:,} lines skipped because the key did not match.")
    if total_mismatches > 0: print(f"   {total_mismatches:,} lines skipped because windows and token counts did not match.")
    if decode_errors > 0: print(f"   {decode_errors:,} lines skipped due to decoding errors.")
    if total_failed > 0: print(f"   {total_failed:,} lines skipped due to other errors.")
    print(f"   Global mapping finished; found {len(sequence_to_raw_map):,} unique token sequences.")
    print("\n🔍 Starting token-level collision analysis...")
    analysis_results = []
    all_distances = []

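    # A collision is one token sequence that maps to more than one distinct raw
    # chunk; for each collision we compute pairwise Levenshtein distances
    # between the decoded raw variants.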
    for token_sequence, raw_chunks_list in tqdm(sequence_to_raw_map.items(), desc="Analyzing collisions"):
        unique_raw_chunks = list(set(raw_chunks_list))
        if len(unique_raw_chunks) > 1:
            raw_strings = [c.decode('utf-8', 'replace') for c in unique_raw_chunks]
            distances = []
            for str1, str2 in combinations(raw_strings, 2):
                dist = Levenshtein.distance(str1, str2)
                distances.append(dist)
            all_distances.extend(distances)
            analysis_results.append({
                "colliding_token_sequence": list(token_sequence),
                "num_raw_variants": len(unique_raw_chunks),
                "raw_chunk_variants": raw_strings,
                "levenshtein_analysis": {
                    "distances": distances,
                    "average_distance": np.mean(distances) if distances else 0,
                    "max_distance": max(distances) if distances else 0,
                    "min_distance": min(distances) if distances else 0,
                }
            })
    print(f"✅ Finished analyzing. Found {len(analysis_results):,} token collisions in total.")
    if not analysis_results:
        print("🎉 Congratulations! No token-level collisions were found."); return
    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, "token_collision_report.json")
    analysis_results.sort(key=lambda x: x['levenshtein_analysis']['average_distance'], reverse=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Analysis report saved to: {output_json_path}")
    print("\n📋 Samples (sorted by average distance):")
    for i, result in enumerate(analysis_results[:5]):
        print("-" * 20)
        print(f"Sample {i+1}:")
        print(f"  Colliding token sequence: {result['colliding_token_sequence']}")
        print(f"  Maps to {result['num_raw_variants']} different raw chunks")
        print(f"  Avg distance: {result['levenshtein_analysis']['average_distance']:.2f}")
        print(f"  Raw 1: {repr(result['raw_chunk_variants'][0][:80])}")
        print(f"  Raw 2: {repr(result['raw_chunk_variants'][1][:80])}")
    output_plot_path = os.path.join(output_dir, "token_collision_levenshtein_distribution.png")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    if all_distances:
        sns.histplot(all_distances, bins=max(50, min(len(set(all_distances)), 100)), kde=False, ax=ax)
        stats_text = (f"Total Colliding Pairs: {len(all_distances):,}\n"
                      f"Mean Distance: {np.mean(all_distances):.2f}\n"
                      f"Median Distance: {np.median(all_distances):.2f}\n"
                      f"Max Distance: {np.max(all_distances):,}")
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
    else:
        ax.text(0.5, 0.5, "No collisions found.", transform=ax.transAxes, fontsize=15,
                verticalalignment='center', horizontalalignment='center')
    ax.set_title('Levenshtein Distance between Raw Chunks with Same Compressed Token (Entire Dataset)', fontsize=14)
    ax.set_xlabel('Levenshtein Distance', fontsize=12)
    ax.set_ylabel('Frequency (Number of Pairs)', fontsize=12)
    ax.set_yscale('log')
    plt.tight_layout()
    plt.savefig(output_plot_path)
    print(f"📊 Levenshtein distance plot saved to: {output_plot_path}")


if __name__ == "__main__":
    try:
        from tqdm import tqdm
    except ImportError:
        print("⚠️ tqdm is not installed (pip install tqdm); progress bars are disabled.")
        def tqdm(iterable, *args, **kwargs):
            return iterable
    parser = argparse.ArgumentParser(
        description="Check for token-level compression collisions across a whole directory",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("input_dir", type=str, help="Input directory containing the .jsonl data files.")
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_token_collision",
                        help="Directory in which to store the analysis output.")
    parser.add_argument("--max_files", type=int, default=-1,
                        help="Maximum number of files to process (-1 = no limit).")
    parser.add_argument("--max_lines", type=int, default=-1,
                        help="Maximum number of lines to process per file (-1 = no limit).")
    args = parser.parse_args()
    analyze_token_collisions_in_directory(
        args.input_dir,
        args.output_dir,
        max_files=args.max_files,
        max_lines=args.max_lines)