import json
import base64
import argparse
import os
import re
from collections import defaultdict
from itertools import combinations
from typing import List, Tuple

# Before running:
#   pip install python-Levenshtein matplotlib seaborn numpy tqdm

# version_1: process a whole directory, matching complete (whole-window) token sequences

try:
    import Levenshtein
except ImportError:
    print("❌ Error: 'python-Levenshtein' is not installed")
    print("Run: pip install python-Levenshtein")
    exit(1)

try:
    from tqdm import tqdm
except ImportError:
    print("⚠️ 'tqdm' is not installed (pip install tqdm); progress bars are disabled")
    def tqdm(iterable, *args, **kwargs):
        return iterable

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


def vread(buf: bytes, i: int):
    """Read one little-endian base-128 varint from buf starting at index i;
    return (value, next_index)."""
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:
            return val, i
        shift += 7
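
# A quick sanity check of the varint format implemented above (little-endian
# base-128 with a 0x80 continuation bit): the bytes 0x8E 0x02 decode to
# 14 + (2 << 7) = 270, with the cursor advancing past both bytes.
assert vread(b"\x8e\x02", 0) == (270, 2)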

def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    """Split input_bytes into (chunk, indicator) segments using the base64-encoded
    stream of (gap, size) varint pairs; indicator 1 marks a compressed window,
    0 marks raw passthrough bytes."""
    try:
        buf, i, cursor, byte_windows = base64.b64decode(b64_stream), 0, 0, []
        while i < len(buf):
            gap,  i = vread(buf, i)
            size, i = vread(buf, i)
            start   = cursor + gap
            if gap > 0: byte_windows.append((input_bytes[cursor:start], 0))
            end     = start + size
            byte_windows.append((input_bytes[start:end], 1))
            cursor  = end
        if cursor < len(input_bytes): byte_windows.append((input_bytes[cursor:], 0))
        return byte_windows
    except (base64.binascii.Error, IndexError): return []
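
# Minimal illustration of the segment format (the (gap, size) varint pairs
# decoded above): over b"abcdef", the stream 0x01 0x02 0x00 0x01 means
# "skip 1 raw byte, take a 2-byte window, skip 0, take a 1-byte window";
# leftover raw bytes are emitted with indicator 0.
assert unpack_windows(
    b"abcdef", base64.b64encode(b"\x01\x02\x00\x01").decode()
) == [(b"a", 0), (b"bc", 1), (b"d", 1), (b"ef", 0)]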


def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Unpack a little-endian bitstream into 9-bit pseudo-tokens (values 0..511)."""
    out, acc, bits = [], 0, 0
    for byte in b:
        acc |= byte << bits
        bits += 8
        while bits >= 9:
            out.append(acc & 0x1FF)
            acc >>= 9
            bits -= 9
    return out
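
# Sanity check of the 9-bit unpacking: a single pseudo-token with value 256
# packs little-endian into the bytes 0x00 0x01 (low 8 bits first, the 9th
# bit landing in the next byte's least significant position).
assert packed_bytes_to_pseudo(b"\x00\x01") == [256]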

# decompress the base64 varint stream into parallel (starts, lens) lists
def decompress_windows_starts_lens(b64_stream: str) -> tuple[list[int], list[int]]:
    try:
        buf   = base64.b64decode(b64_stream)
        i     = 0
        cursor= 0
        starts, lens = [], []
        while i < len(buf):
            gap,  i = vread(buf, i)
            size, i = vread(buf, i)
            start   = cursor + gap
            length  = size
            starts.append(start)
            lens.append(length)
            cursor  = start + length
        return starts, lens
    except (base64.binascii.Error, IndexError):
        # decode failed
        return [], []
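
# The example stream from above decodes to absolute offsets: (gap=1, size=2)
# then (gap=0, size=1) gives window starts [1, 3] and lengths [2, 1].
assert decompress_windows_starts_lens(
    base64.b64encode(b"\x01\x02\x00\x01").decode()
) == ([1, 3], [2, 1])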


def parse_parameters_from_path(file_path: str) -> dict:
    """parse parameters"""
    params = {}
    base_name = os.path.basename(os.path.normpath(file_path))
    parts = base_name.split('_')
    for part in parts:
        if '-' in part:
            # key-value pair, e.g., "iterative-true"
            key, value = part.split('-', 1)
            params[key.lower()] = value.lower()
        else:
            # concatenated key+value, e.g., "ow20"
            # use re.match so the key/value is only matched from the start of the part
            match = re.match(r'([a-zA-Z]+)(\d+)', part)
            if match:
                key, value = match.groups()
                params[key.lower()] = value
    
    params['bits_per_compressed'] = 10  # hardcoded for this format
    print(f"Parsed params from path '{base_name}': {params}")
    return params
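
# For example, a hypothetical directory named "ocp_ow20_iterative-true" would
# yield {'ow': '20', 'iterative': 'true', 'bits_per_compressed': 10}; note
# that values parsed from the name stay strings.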


def construct_compression_key(params: dict) -> str:
    """construct compressed data key。"""
    ow = params.get('ow', 20)
    escape_fb = 'True' if params.get('escapefb', 'false') == 'true' else 'False'
    iterative = 'True' if params.get('iterative', 'true') == 'true' else 'False'
    force_padding = 'True' if params.get('forcepadding', 'false') == 'true' else 'False'
    key = f"m1_ac_ow{ow}_escapefb-{escape_fb}_iterative-{iterative}_forcepadding-{force_padding}"
    print(f"Compress data Key is: '{key}'")
    return key
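
# For example, {'ow': '20', 'iterative': 'true'} (escape/padding flags absent,
# so they default to False) produces
#   "m1_ac_ow20_escapefb-False_iterative-True_forcepadding-False"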


def analyze_token_collisions_in_directory(input_dir: str, 
                                          output_dir: str, 
                                          compression_offset: int = 256,
                                          max_files: int = -1,
                                          max_lines: int = -1):
    """for dir token-level collusion """
    if not os.path.isdir(input_dir):
        print(f"❌ Error: input file is not valid '{input_dir}'"); return
    # from dir get params
    params = parse_parameters_from_path(input_dir)
    if 'ow' not in params:
        print(f"❌ Error: can not from '{input_dir}' get 'ow'"); return
    compression_bit_threshold = params['ow']
    bits_per_compressed = params['bits_per_compressed'] 
    compression_key = construct_compression_key(params)
    print(f"compress Key is: '{compression_key}'")
    print(f"params: compression_bit_threshold={compression_bit_threshold}, bits_per_compressed={bits_per_compressed}")

    # collect all .jsonl files (prefix 'ocp')
    jsonl_files = []
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.jsonl') and file.startswith('ocp'):
                jsonl_files.append(os.path.join(root, file))
    if not jsonl_files:
        print(f"❌ Error: no more'{input_dir}' .jsonl"); return
    print(f"🔍 Find {len(jsonl_files)} .jsonl files to address")
    
    # small-batch debug mode
    if max_files > 0:
        jsonl_files = jsonl_files[:max_files]  # keep only the first max_files files
        print(f"🔍 Small-batch mode: processing only {len(jsonl_files)} files, at most {max_lines} lines each")

    # global mapping: token sequence (tuple) -> list of raw byte chunks
    sequence_to_raw_map = defaultdict(list)
    total_lines = 0
    total_mismatches = 0
    key_not_found_count = 0
    decode_errors = 0
    total_processed = 0
    total_failed = 0
    print("🚀 Processing all files and building the global token-sequence -> raw-chunk map...")
    
    for file_path in tqdm(jsonl_files, desc="Processing files"):
        with open(file_path, 'r', errors='ignore') as f:
            line_nums = 0
            for line in f:
                line_nums += 1
                if max_lines > 0 and line_nums > max_lines:
                    print(f"📌 文件 {os.path.basename(file_path)} 已处理 {max_lines} 行,停止读取")
                    break
                total_lines += 1
                try:
                    data = json.loads(line)
                    if compression_key not in data:
                        # print a one-time debug hint if the constructed key never matches
                        if key_not_found_count == 0:
                            print("\n\n--- Debug: key mismatch ---")
                            print(f"Constructed key: '{compression_key}'")
                            print(f"Available keys in the JSON: {list(data.keys())}")
                            print("---------------------------")
                        key_not_found_count += 1
                        continue

                    required_keys = [compression_key, 'text', 'windows_starts_lens_b64', 'pseudo_lens_per_segment']
                    if not all(k in data and data[k] for k in required_keys):
                        # skip lines with missing or empty required fields
                        continue
                    
                    # 1. decode the base64 payload and unpack it into the mixed pseudo-token stream
                    b64_decoded_bytes = base64.b64decode(data[compression_key])
                    mixed_pseudo_bytes = packed_bytes_to_pseudo(b64_decoded_bytes)

                    # 2. split the original text into raw / compressed segments.
                    #    starts+lens alone only cover the compressed windows, so
                    #    unpack_windows re-derives the raw gaps between them as well.
                    raw_text_bytes = data['text'].encode('utf-8')
                    all_segments = unpack_windows(raw_text_bytes, data['windows_starts_lens_b64'])
                    pseudo_lens = data["pseudo_lens_per_segment"]

                    if len(pseudo_lens) != len(all_segments):
                        raise ValueError("Metadata length mismatch between pseudo_lens and all_segments")


                    # Example alignment (pseudo-tokens >= 256 replace runs of raw bytes):
                    #   pseudo stream: 1 2 3 355 356 1 2 3
                    #   raw bytes:     1 2 3 17 18 19 1 2 3

                    # 3. walk the pseudo stream with a pointer, segment by segment
                    pseudo_ptr = 0
                    for i in range(len(all_segments)):
                        raw_chunk, indicator = all_segments[i]
                        segment_pseudo_len = pseudo_lens[i]
 
                        if pseudo_ptr >= len(mixed_pseudo_bytes):
                            raise ValueError("Pseudo bytes stream exhausted prematurely.")
                        if indicator == 0: # raw passthrough: skip its literal pseudo-tokens
                            pseudo_ptr += segment_pseudo_len
                        elif indicator == 1: # compressed window
                            if pseudo_ptr + segment_pseudo_len > len(mixed_pseudo_bytes):
                                raise ValueError("Pseudo bytes stream exhausted for a window.")
                            # this window's raw bytes, and its slice of the mixed pseudo stream
                            current_raw_bytes = list(raw_chunk)
                            current_pseudo_sequence = mixed_pseudo_bytes[pseudo_ptr : pseudo_ptr + segment_pseudo_len]

                            # strip the shared literal tail: pseudo-values < 256 are raw byte literals
                            while current_pseudo_sequence and current_pseudo_sequence[-1] < 256:
                                last_pseudo = current_pseudo_sequence.pop()
                                last_raw = current_raw_bytes.pop()
                                # trailing literals must equal the raw bytes they stand for
                                assert last_pseudo == last_raw, "Mismatch in raw tail"
                            # after trimming, record the mapping
                            if current_pseudo_sequence:
                                pure_token_sequence = tuple(current_pseudo_sequence)
                                pure_raw_chunk = bytes(current_raw_bytes)
                                sequence_to_raw_map[pure_token_sequence].append(pure_raw_chunk)
                            # advance the pointer past this segment
                            pseudo_ptr += segment_pseudo_len
                    total_processed += 1
                
                except Exception:
                    total_failed += 1
                    continue

    print(f"✅ All file addressed. Total address{total_lines:,} lines")
    if key_not_found_count > 0: print(f"   {key_not_found_count:,} 行因 key 不匹配被跳过。")
    if total_mismatches > 0: print(f"   {total_mismatches:,} 行因窗口与token数不匹配被跳过。")
    if decode_errors > 0: print(f"   {decode_errors:,} 行因解码错误被跳过。")
    print(f"   Global mapping Finished, all find {len(sequence_to_raw_map):,} unique token。")
    print("\n🔍 Start analysis token-level collusion...")
    analysis_results = []
    all_distances = []
    
    # keep only token sequences that map to more than one distinct raw chunk
    for token_sequence, raw_chunks_list in tqdm(sequence_to_raw_map.items(), desc="Analyzing collisions"):
        unique_raw_chunks = list(set(raw_chunks_list))
        if len(unique_raw_chunks) > 1:
            raw_strings = [c.decode('utf-8', 'replace') for c in unique_raw_chunks]
            distances = []
            for str1, str2 in combinations(raw_strings, 2):
                dist = Levenshtein.distance(str1, str2)
                distances.append(dist)
            all_distances.extend(distances)
            analysis_results.append({
                "colliding_token_sequence": list(token_sequence),
                "num_raw_variants": len(unique_raw_chunks),
                "raw_chunk_variants": raw_strings,
                "levenshtein_analysis": {
                    "distances": distances,
                    "average_distance": np.mean(distances) if distances else 0,
                    "max_distance": max(distances) if distances else 0,
                    "min_distance": min(distances) if distances else 0,
                }
            })
    print(f"✅ Finshed analysising. Total {len(analysis_results):,} token collusion.")
    if not analysis_results:
        print("🎉 Congradulations! no find token-level collusions."); return
    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, "token_collision_report.json")
    analysis_results.sort(key=lambda x: x['levenshtein_analysis']['average_distance'], reverse=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Analysis is saved to: {output_json_path}")
    print("\n📋 Sampel(Avg order):")
    for i, result in enumerate(analysis_results[:5]):
        print("-" * 20)
        print(f"Sample {i+1}:")
        print(f"  Collusion Token: {result['colliding_token_sequence']}")
        print(f"  To {result['num_raw_variants']} diff raw bytes")
        print(f"  Avg Distance: {result['levenshtein_analysis']['average_distance']:.2f}")
        print(f"  Raw 1: {repr(result['raw_chunk_variants'][0][:80])}")
        print(f"  Raw 2: {repr(result['raw_chunk_variants'][1][:80])}")
    output_plot_path = os.path.join(output_dir, "token_collision_levenshtein_distribution.png")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    if all_distances:
        sns.histplot(all_distances, bins=max(50, min(len(set(all_distances)), 100)), kde=False, ax=ax)
        stats_text = (f"Total Colliding Pairs: {len(all_distances):,}\n"
                      f"Mean Distance: {np.mean(all_distances):.2f}\n"
                      f"Median Distance: {np.median(all_distances):.2f}\n"
                      f"Max Distance: {np.max(all_distances):,}")
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
    else:
        ax.text(0.5, 0.5, "No collisions found.", transform=ax.transAxes, fontsize=15,
                verticalalignment='center', horizontalalignment='center')
    ax.set_title('Levenshtein Distance between Raw Chunks with Same Compressed Token (Entire Dataset)', fontsize=14)
    ax.set_xlabel('Levenshtein Distance', fontsize=12)
    ax.set_ylabel('Frequency (Number of Pairs)', fontsize=12)
    ax.set_yscale('log')
    plt.tight_layout()
    plt.savefig(output_plot_path)
    print(f"📊 Lev distance is saved to: {output_plot_path}")


## Usage: python analysis_dynamic_dis.py /mnt/hdfs/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check for token-level compression collisions across a dataset directory.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("input_dir", type=str, help="Input directory containing the .jsonl data files.")
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_token_collision",
                        help="Directory to store the analysis outputs.")

    parser.add_argument("--max_files", type=int, default=-1,
                        help="Maximum number of files to process (-1 = no limit).")
    parser.add_argument("--max_lines", type=int, default=-1,
                        help="Maximum number of lines to read per file (-1 = no limit).")
    args = parser.parse_args()
    analyze_token_collisions_in_directory(
        args.input_dir, 
        args.output_dir,
        max_files=args.max_files,
        max_lines=args.max_lines)