# Byte-lingua-code / analysis_dynamic_dis.py
# 2ira's picture
# offline_compression_graph_code
# 72c0672 verified
import json
import base64
import argparse
import os
from collections import defaultdict
from itertools import combinations
import re
from tqdm import tqdm
from typing import List, Dict, Any, Tuple
import argparse
# before running: pip install python-Levenshtein
# pip install matplotlib
# pip install seaborn
# numpy
# version_1 : for dir, complete matched
try:
import Levenshtein
except ImportError:
print("❌ errord: 'python-Levenshtein' have not be installed")
print("run:pip install python-Levenshtein")
exit(1)
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def vread(buf: bytes, i: int):
    """Decode one LEB128-style unsigned varint from ``buf`` at index ``i``.

    Each byte contributes its low 7 bits (little-endian groups); the high
    bit marks continuation. Returns ``(value, next_index)`` where
    ``next_index`` is the position just past the varint.
    """
    result = 0
    offset = 0
    # Consume continuation bytes (high bit set).
    while buf[i] >= 0x80:
        result |= (buf[i] & 0x7F) << offset
        offset += 7
        i += 1
    # Final byte has the high bit clear.
    result |= buf[i] << offset
    return result, i + 1
def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    """Split ``input_bytes`` into ``(chunk, indicator)`` segments.

    ``b64_stream`` is a base64-encoded sequence of varint ``(gap, size)``
    pairs. ``gap`` bytes before each window are emitted with indicator 0
    (raw/literal), the window itself with indicator 1 (compressed), and any
    trailing bytes with indicator 0. Returns ``[]`` on any decode failure.
    """
    try:
        stream = base64.b64decode(b64_stream)
        segments: List[Tuple[bytes, int]] = []
        pos = 0
        offset = 0
        while pos < len(stream):
            gap, pos = vread(stream, pos)
            size, pos = vread(stream, pos)
            win_start = offset + gap
            if gap > 0:
                # Literal bytes between the previous window and this one.
                segments.append((input_bytes[offset:win_start], 0))
            win_end = win_start + size
            segments.append((input_bytes[win_start:win_end], 1))
            offset = win_end
        if offset < len(input_bytes):
            # Literal tail after the last window.
            segments.append((input_bytes[offset:], 0))
        return segments
    except (base64.binascii.Error, IndexError):
        return []
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Unpack a little-endian bitstream of 9-bit tokens from ``b``.

    Bytes are shifted into an accumulator 8 bits at a time; every time at
    least 9 bits are pending, one token (``& 0x1FF``) is emitted. Trailing
    bits that do not complete a 9-bit token are discarded.
    """
    tokens: list[int] = []
    pending = 0
    pending_bits = 0
    for value in b:
        pending |= value << pending_bits
        pending_bits += 8
        while pending_bits >= 9:
            tokens.append(pending & 0x1FF)
            pending >>= 9
            pending_bits -= 9
    return tokens
# Decode a base64 varint stream into window start/length lists.
def decompress_windows_starts_lens(b64_stream: str) -> tuple[list[int], list[int]]:
    """Decode varint ``(gap, size)`` pairs into absolute window positions.

    Returns ``(starts, lens)`` — parallel lists of window start offsets and
    window lengths — or ``([], [])`` if the stream cannot be decoded.
    """
    starts: list[int] = []
    lens: list[int] = []
    try:
        data = base64.b64decode(b64_stream)
        pos = 0
        offset = 0
        while pos < len(data):
            gap, pos = vread(data, pos)
            size, pos = vread(data, pos)
            begin = offset + gap
            starts.append(begin)
            lens.append(size)
            # Next gap is measured from the end of this window.
            offset = begin + size
        return starts, lens
    except (base64.binascii.Error, IndexError):
        # Decode failed: report empty results.
        return [], []
def parse_parameters_from_path(file_path: str, default_bits_per_compressed: int = 10) -> dict:
    """Extract run parameters encoded in a file or directory name.

    The basename is split on ``'_'``; each part is parsed either as
    ``"key-value"`` (e.g. ``"iterative-true"``, value lower-cased) or as
    ``"keyNN"`` (e.g. ``"ow20"``, value kept as a digit string). Parts
    matching neither form are ignored.

    Args:
        file_path: path whose basename encodes the parameters.
        default_bits_per_compressed: value stored under
            ``'bits_per_compressed'``. Previously a hard-coded ``10``;
            now a parameter with the same default for compatibility.

    Returns:
        dict of lower-cased keys to parsed values, always containing
        ``'bits_per_compressed'``.
    """
    params = {}
    base_name = os.path.basename(os.path.normpath(file_path))
    for part in base_name.split('_'):
        if '-' in part:
            # "key-value" form, e.g. "iterative-true"
            key, value = part.split('-', 1)
            params[key.lower()] = value.lower()
        else:
            # "keyNN" form, e.g. "ow20"; re.match anchors at the start only
            match = re.match(r'([a-zA-Z]+)(\d+)', part)
            if match:
                key, value = match.groups()
                params[key.lower()] = value
    params['bits_per_compressed'] = default_bits_per_compressed
    print(f"From path '{base_name}' to parse params: {params}")
    return params
def construct_compression_key(params: dict) -> str:
    """Build the JSONL field name that holds the compressed payload."""
    def as_flag(name: str, default: str) -> str:
        # Map a 'true'/'false' string param onto the 'True'/'False' spelling
        # used inside the key.
        return 'True' if params.get(name, default) == 'true' else 'False'

    window = params.get('ow', 20)
    key = (
        f"m1_ac_ow{window}"
        f"_escapefb-{as_flag('escapefb', 'false')}"
        f"_iterative-{as_flag('iterative', 'true')}"
        f"_forcepadding-{as_flag('forcepadding', 'false')}"
    )
    print(f"Compress data Key is: '{key}'")
    return key
def analyze_token_collisions_in_directory(input_dir: str,
                                          output_dir: str,
                                          compression_offset: int = 256,
                                          max_files: int = -1,
                                          max_lines: int = -1):
    """Scan a directory of ocp*.jsonl files and report token-level collisions.

    For every compressed window, the pseudo-token sequence (tokens >= 256,
    with trailing literal tokens < 256 stripped) is mapped to the raw byte
    chunk it encodes. A sequence mapping to more than one distinct raw chunk
    is a collision; collisions are scored by pairwise Levenshtein distance
    and written to ``<output_dir>/token_collision_report.json`` plus a
    distance-histogram PNG.

    Args:
        input_dir: directory whose NAME encodes run parameters (see
            ``parse_parameters_from_path``) and which contains ocp*.jsonl
            files, searched recursively.
        output_dir: destination for the JSON report and the PNG plot.
        compression_offset: kept for interface compatibility; tokens below
            256 are treated as literal raw bytes.
        max_files: if > 0, only the first N files are processed (debugging).
        max_lines: if > 0, only the first N lines of each file are read.

    Returns:
        None. Progress is printed; results are written as side effects.
    """
    if not os.path.isdir(input_dir):
        print(f"❌ Error: input file is not valid '{input_dir}'"); return
    # Derive run parameters from the directory name.
    params = parse_parameters_from_path(input_dir)
    if 'ow' not in params:
        print(f"❌ Error: can not from '{input_dir}' get 'ow'"); return
    compression_bit_threshold = params['ow']
    bits_per_compressed = params['bits_per_compressed']
    compression_key = construct_compression_key(params)
    print(f"compress Key is: '{compression_key}'")
    print(f"params: compression_bit_threshold={compression_bit_threshold}, bits_per_compressed={bits_per_compressed}")
    # Collect all ocp*.jsonl files recursively.
    jsonl_files = []
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.jsonl') and file.startswith('ocp'):
                jsonl_files.append(os.path.join(root, file))
    if not jsonl_files:
        print(f"❌ Error: no more'{input_dir}' .jsonl"); return
    print(f"🔍 Find {len(jsonl_files)} .jsonl files to address")
    # Small-batch debug mode: cap the number of files processed.
    if max_files > 0:
        jsonl_files = jsonl_files[:max_files]
        print(f"🔍 小批次模式:仅处理 {len(jsonl_files)} 个文件,每个文件最多 {max_lines} 行")
    # Global map: pseudo-token sequence (tuple) -> list of raw byte chunks.
    sequence_to_raw_map = defaultdict(list)
    total_lines = 0
    total_mismatches = 0
    key_not_found_count = 0
    decode_errors = 0
    total_failed = 0
    # FIX: total_processed was never initialized; its first `+= 1` raised
    # NameError inside the try block, so every line was counted as failed.
    total_processed = 0
    print("🚀 Start addressing all, build global token -> raw_chunk_list map...")
    for file_path in tqdm(jsonl_files, desc="Processing files"):
        with open(file_path, 'r', errors='ignore') as f:
            line_nums = 0
            for line in f:
                line_nums += 1
                if max_lines > 0 and line_nums > max_lines:
                    print(f"📌 文件 {os.path.basename(file_path)} 已处理 {max_lines} 行,停止读取")
                    break
                total_lines += 1
                try:
                    data = json.loads(line)
                    # Skip lines missing the compressed payload or window metadata.
                    if compression_key not in data or not data[compression_key] or \
                       'windows_starts_lens_b64' not in data or not data['windows_starts_lens_b64']:
                        continue
                    required_keys = [compression_key, 'text', 'windows_starts_lens_b64', 'pseudo_lens_per_segment']
                    if not all(k in data and data[k] for k in required_keys):
                        print(f"some key is not exist")
                        continue
                    # 1. Decode the mixed stream of 9-bit pseudo tokens.
                    b64_decoded_bytes = base64.b64decode(data[compression_key])
                    mixed_pseudo_bytes = packed_bytes_to_pseudo(b64_decoded_bytes)
                    # 2. Split the original text into (chunk, indicator) segments.
                    raw_text_bytes = data['text'].encode('utf-8')
                    all_segments = unpack_windows(raw_text_bytes, data['windows_starts_lens_b64'])
                    pseudo_lens = data["pseudo_lens_per_segment"]
                    if len(pseudo_lens) != len(all_segments):
                        raise ValueError("Metadata length mismatch between pseudo_lens and all_segments")
                    # 3. Walk the pseudo stream in lockstep with the segments.
                    #    Raw segments (indicator 0) only advance the pointer;
                    #    compressed windows (indicator 1) are mapped.
                    pseudo_ptr = 0
                    for i in range(len(all_segments)):
                        raw_chunk, indicator = all_segments[i]
                        segment_pseudo_len = pseudo_lens[i]
                        if pseudo_ptr >= len(mixed_pseudo_bytes):
                            raise ValueError("Pseudo bytes stream exhausted prematurely.")
                        if indicator == 0:  # literal bytes: just skip
                            pseudo_ptr += segment_pseudo_len
                        elif indicator == 1:  # compressed window
                            if pseudo_ptr + segment_pseudo_len > len(mixed_pseudo_bytes):
                                raise ValueError("Pseudo bytes stream exhausted for a window.")
                            current_raw_bytes = list(raw_chunk)
                            current_pseudo_sequence = mixed_pseudo_bytes[pseudo_ptr : pseudo_ptr + segment_pseudo_len]
                            # Strip trailing literal tokens (< 256); each must
                            # mirror the corresponding raw tail byte.
                            while current_pseudo_sequence and current_pseudo_sequence[-1] < 256:
                                last_pseudo = current_pseudo_sequence.pop()
                                last_raw = current_raw_bytes.pop()
                                assert last_pseudo == last_raw, "Mismatch in raw tail"
                            # Record the mapping once only compressed tokens remain.
                            if current_pseudo_sequence:
                                pure_token_sequence = tuple(current_pseudo_sequence)
                                pure_raw_chunk = bytes(current_raw_bytes)
                                sequence_to_raw_map[pure_token_sequence].append(pure_raw_chunk)
                            pseudo_ptr += segment_pseudo_len
                    total_processed += 1
                except Exception:
                    # Best-effort scan: malformed lines are counted and skipped.
                    total_failed += 1
                    continue
    print(f"✅ All file addressed. Total address{total_lines:,} lines")
    if key_not_found_count > 0: print(f" {key_not_found_count:,} 行因 key 不匹配被跳过。")
    if total_mismatches > 0: print(f" {total_mismatches:,} 行因窗口与token数不匹配被跳过。")
    if decode_errors > 0: print(f" {decode_errors:,} 行因解码错误被跳过。")
    print(f" Global mapping Finished, all find {len(sequence_to_raw_map):,} unique token。")
    print("\n🔍 Start analysis token-level collusion...")
    analysis_results = []
    all_distances = []
    # A collision = one token sequence mapping to >1 distinct raw chunks.
    for token_sequence, raw_chunks_list in tqdm(sequence_to_raw_map.items(), desc="Analyzing collisions"):
        unique_raw_chunks = list(set(raw_chunks_list))
        if len(unique_raw_chunks) > 1:
            raw_strings = [c.decode('utf-8', 'replace') for c in unique_raw_chunks]
            distances = [Levenshtein.distance(s1, s2) for s1, s2 in combinations(raw_strings, 2)]
            all_distances.extend(distances)
            analysis_results.append({
                "colliding_token_sequence": list(token_sequence),
                "num_raw_variants": len(unique_raw_chunks),
                "raw_chunk_variants": raw_strings,
                "levenshtein_analysis": {
                    "distances": distances,
                    "average_distance": np.mean(distances) if distances else 0,
                    "max_distance": max(distances) if distances else 0,
                    "min_distance": min(distances) if distances else 0,
                }
            })
    print(f"✅ Finshed analysising. Total {len(analysis_results):,} token collusion.")
    if not analysis_results:
        print("🎉 Congradulations! no find token-level collusions."); return
    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, "token_collision_report.json")
    # Most-divergent collisions first.
    analysis_results.sort(key=lambda x: x['levenshtein_analysis']['average_distance'], reverse=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Analysis is saved to: {output_json_path}")
    print("\n📋 Sampel(Avg order):")
    for i, result in enumerate(analysis_results[:5]):
        print("-" * 20)
        print(f"Sample {i+1}:")
        print(f" Collusion Token: {result['colliding_token_sequence']}")
        print(f" To {result['num_raw_variants']} diff raw bytes")
        print(f" Avg Distance: {result['levenshtein_analysis']['average_distance']:.2f}")
        print(f" Raw 1: {repr(result['raw_chunk_variants'][0][:80])}")
        print(f" Raw 2: {repr(result['raw_chunk_variants'][1][:80])}")
    # Histogram of pairwise Levenshtein distances (log-scaled frequency).
    output_plot_path = os.path.join(output_dir, "token_collision_levenshtein_distribution.png")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    if all_distances:
        sns.histplot(all_distances, bins=max(50, min(len(set(all_distances)), 100)), kde=False, ax=ax)
        stats_text = (f"Total Colliding Pairs: {len(all_distances):,}\n"
                      f"Mean Distance: {np.mean(all_distances):.2f}\n"
                      f"Median Distance: {np.median(all_distances):.2f}\n"
                      f"Max Distance: {np.max(all_distances):,}")
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
    else:
        ax.text(0.5, 0.5, "No collisions found.", transform=ax.transAxes, fontsize=15,
                verticalalignment='center', horizontalalignment='center')
    ax.set_title('Levenshtein Distance between Raw Chunks with Same Compressed Token (Entire Dataset)', fontsize=14)
    ax.set_xlabel('Levenshtein Distance', fontsize=12)
    ax.set_ylabel('Frequency (Number of Pairs)', fontsize=12)
    ax.set_yscale('log')
    plt.tight_layout()
    plt.savefig(output_plot_path)
    print(f"📊 Lev distance is saved to: {output_plot_path}")
## python analysis_dynamic_dis.py /mnt/hdfs/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac
if __name__ == "__main__":
try:
from tqdm import tqdm
except ImportError:
print("pip install tqdm")
def tqdm(iterable, *args, **kwargs):
return iterable
parser = argparse.ArgumentParser(
description="check all token-level Compression collusion",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument("input_dir", type=str, help="including .jsonl data input die。")
parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_token_collision",
help="store output")
parser.add_argument("--max_files", type=int, default=-1,
help="set max addressing files")
parser.add_argument("--max_lines", type=int, default=-1,
help="the most addressing line")
args = parser.parse_args()
analyze_token_collisions_in_directory(
args.input_dir,
args.output_dir,
max_files=args.max_files,
max_lines=args.max_lines)