# File size: 13,926 Bytes
# 72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import json
import base64
from offline_utils import packed_bytes_to_pseudo

def pack_compressed_spans(data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Convert consecutive compressed values into larger integers.

    Values below ``compression_offset`` are raw bytes and are copied through
    unchanged. Each maximal run of values ``>= compression_offset`` is read
    (offset removed) as a big-endian byte stream of
    ``compression_bit_threshold``-bit values. Each value is stored
    left-aligned in the smallest whole number of bytes, and its low
    ``padding_bits`` bits are padding that is discarded. The meaningful bits
    of every value are then re-cut into ``bits_per_compressed``-bit words,
    and each word is re-offset by ``compression_offset``.

    Args:
        data: List of integers where 0-255 are raw bytes, compression_offset+ are compressed bytes
        bits_per_compressed: Number of bits to use for each packed value
        compression_bit_threshold: Number of bits each compressed value actually uses
        compression_offset: Offset that marks start of compressed values (default 256)

    Returns:
        List with consecutive compressed spans packed into larger integers

    Raises:
        AssertionError: if ``compression_bit_threshold`` is not divisible by
            ``bits_per_compressed``.
        ValueError: if ``compression_bit_threshold`` is not positive, if a
            compressed value does not fit in one byte after removing the
            offset, or if a compressed span is not a whole number of padded
            values. Such trailing bytes were previously dropped silently.
    """
    if not data:
        return []

    assert compression_bit_threshold % bits_per_compressed == 0, "compression_bit_threshold must be divisible by bits_per_compressed"
    if compression_bit_threshold <= 0:
        # A zero width would never consume input (the old code looped forever).
        raise ValueError("compression_bit_threshold must be positive")

    packing_mask = (1 << bits_per_compressed) - 1
    compression_mask = (1 << compression_bit_threshold) - 1

    # Each compressed value occupies a byte-aligned container; the low
    # `padding_bits` bits of that container are padding.
    padded_bytes = (compression_bit_threshold + 7) // 8
    padding_bits = padded_bytes * 8 - compression_bit_threshold

    # Shifts that cut one value into bits_per_compressed-bit words, MSB first.
    word_shifts = range(compression_bit_threshold - bits_per_compressed, -1, -bits_per_compressed)

    result = []
    i = 0
    n = len(data)
    while i < n:
        if data[i] < compression_offset:
            # Raw byte (0-255), keep as is
            result.append(data[i])
            i += 1
            continue

        # Find the end of consecutive compressed bytes
        span_start = i
        while i < n and data[i] >= compression_offset:
            i += 1
        compressed_span = data[span_start:i]

        # A partial trailing container cannot be decoded; fail loudly instead
        # of dropping those bytes.
        if len(compressed_span) % padded_bytes != 0:
            raise ValueError(
                f"compressed span starting at index {span_start} has {len(compressed_span)} values, "
                f"which is not a multiple of {padded_bytes} bytes per {compression_bit_threshold}-bit value"
            )

        for chunk_start in range(0, len(compressed_span), padded_bytes):
            # Reassemble one byte-aligned container, most significant byte first.
            padded_val = 0
            for value in compressed_span[chunk_start:chunk_start + padded_bytes]:
                byte_val = value - compression_offset
                if byte_val > 0xFF:
                    raise ValueError(
                        f"compressed value {value} is outside the byte range "
                        f"[{compression_offset}, {compression_offset + 0xFF}]"
                    )
                padded_val = (padded_val << 8) | byte_val

            # Strip padding by extracting only the meaningful bits
            extracted_val = (padded_val >> padding_bits) & compression_mask

            # Emit the value as bits_per_compressed-bit words, MSB first.
            for shift in word_shifts:
                result.append(((extracted_val >> shift) & packing_mask) + compression_offset)

    return result

def unpack_compressed_spans(packed_data, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """
    Reverse operation: unpack larger integers back to consecutive compressed bytes.

    Values below ``compression_offset`` pass through unchanged. Every maximal
    run of values ``>= compression_offset`` is read (offset removed) as a
    big-endian stream of ``bits_per_compressed``-bit words. The stream is
    re-cut into ``compression_bit_threshold``-bit values. Each value is
    shifted left into the smallest whole number of bytes, so its low bits are
    zero padding. It is then emitted most-significant byte first, with every
    byte re-offset by ``compression_offset``.

    Args:
        packed_data: List with packed compressed spans
        bits_per_compressed: Number of bits used for packing
        compression_bit_threshold: Number of bits each compressed value actually uses
        compression_offset: Offset used for compressed values

    Returns:
        Original format with consecutive compressed bytes

    Raises:
        AssertionError: if a run leaves a partial value behind, i.e. its total
            bit count is not a multiple of ``compression_bit_threshold``.
    """
    byte_width = (compression_bit_threshold + 7) // 8
    pad_bits = byte_width * 8 - compression_bit_threshold

    def expand_run(words):
        # Turn one run of offset-free packed words into offset padded bytes.
        value_mask = (1 << compression_bit_threshold) - 1
        expanded = []
        acc, acc_bits = 0, 0
        for word in words:
            acc = (acc << bits_per_compressed) | word
            acc_bits += bits_per_compressed
            # Peel off a full value whenever enough bits have accumulated.
            while acc_bits >= compression_bit_threshold:
                acc_bits -= compression_bit_threshold
                value = (acc >> acc_bits) & value_mask
                acc &= (1 << acc_bits) - 1
                # Left-align into byte_width bytes, then split MSB first.
                aligned = value << pad_bits
                expanded.extend(
                    ((aligned >> (8 * k)) & 0xFF) + compression_offset
                    for k in range(byte_width - 1, -1, -1)
                )
        # Every packed bit must belong to some whole value.
        assert acc_bits == 0, "bits_in_buffer must be 0 after unpacking compressed span"
        return expanded

    result = []
    pos, total = 0, len(packed_data)
    while pos < total:
        if packed_data[pos] >= compression_offset:
            end = pos + 1
            while end < total and packed_data[end] >= compression_offset:
                end += 1
            result.extend(expand_run([w - compression_offset for w in packed_data[pos:end]]))
            pos = end
        else:
            # Raw byte, keep as is
            result.append(packed_data[pos])
            pos += 1

    return result

def run_test_case(test_name: str, data: list, bits_per_compressed: int, compression_bit_threshold: int, compression_offset: int = 256):
    """Run a single pack/unpack round-trip test case and print a report.

    Packs ``data`` with ``pack_compressed_spans``, unpacks the result with
    ``unpack_compressed_spans``, and checks that the round trip reproduces
    ``data`` exactly. Also prints how many compressed values remain after
    packing compared with the input.

    Args:
        test_name: Label printed before the case.
        data: Input values (raw bytes below ``compression_offset``,
            compressed bytes at or above it).
        bits_per_compressed: Passed through to pack/unpack.
        compression_bit_threshold: Passed through to pack/unpack.
        compression_offset: Offset marking compressed values; passed through
            to pack/unpack and used for the stats (default 256).

    Returns:
        True if the round trip matched; False on mismatch or if pack/unpack
        raised any exception.
    """
    print(f"🧪 {test_name}")
    print(f"   Original: {data}")
    print(f"   Config: bits_per_compressed={bits_per_compressed}, compression_bit_threshold={compression_bit_threshold}")

    try:
        packed = pack_compressed_spans(data, bits_per_compressed, compression_bit_threshold, compression_offset)
        print(f"   Packed:   {packed}")

        unpacked = unpack_compressed_spans(packed, bits_per_compressed, compression_bit_threshold, compression_offset)
        print(f"   Unpacked: {unpacked}")
    except Exception as e:
        # Report and keep going so one bad case does not abort the suite.
        # The type name matters: a bare `assert` failure has an empty message.
        print(f"   Result:   ❌ ERROR: {type(e).__name__}: {e}")
        return False

    # Verify round-trip
    success = data == unpacked
    print(f"   Result:   {'✅ PASS' if success else '❌ FAIL'}")

    # Show compression stats
    original_compressed = sum(1 for x in data if x >= compression_offset)
    packed_compressed = sum(1 for x in packed if x >= compression_offset)
    if original_compressed > 0:
        ratio = original_compressed / packed_compressed if packed_compressed > 0 else 0
        print(f"   Stats:    {original_compressed} → {packed_compressed} compressed values ({ratio:.2f}x)")

    return success


def test_packing_comprehensive():
    """Comprehensive test suite for packing functions.

    Runs a fixed list of round-trip cases through ``run_test_case`` and
    prints a summary. The list contains randomly generated compressed spans
    for several bit widths plus hand-written edge cases.

    Returns:
        ``(passed, total)`` counts of round-trip cases.
    """
    import random

    from m1_compression import utils

    def random_bytes_generator(n: int, bit_threshold: int):
        """Build random compressed spans followed by one raw byte.

        Emits between ``n // 2`` and ``n`` random ``bit_threshold``-bit values.
        Each is converted to bytes by ``utils.bits_to_bytes_padding_to_threshold``
        and offset by 256. A single random raw byte (0-255) is appended at the end.
        """
        ret = []
        for _value_idx in range(random.randint(n // 2, n)):
            bits = "".join("0" if random.random() < 0.5 else "1" for _ in range(bit_threshold))
            compressed_bytes = utils.bits_to_bytes_padding_to_threshold(bits, bit_threshold)[0]
            ret.extend(c + 256 for c in compressed_bytes)
        ret.append(random.randint(0, 255))
        return ret

    print("=" * 60)
    print("🚀 COMPREHENSIVE PACKING TESTS")
    print("=" * 60)

    # (name, data, bits_per_compressed, compression_bit_threshold)
    cases = [
        # Test 1: Basic functionality - 16-bit alignment (no padding)
        ("Basic 16-bit packing (no padding)", random_bytes_generator(100, 16), 16, 16),
        # Test 2: 12-bit values with padding (12 bits stored in 16 bits)
        ("12-bit values with 4-bit padding", random_bytes_generator(100, 12), 12, 12),
        # Test 3: 20-bit values with padding (20 bits stored in 24 bits)
        ("20-bit values with 4-bit padding", random_bytes_generator(100, 20), 20, 20),
        # Test 4: Edge case - single compressed byte
        ("Single compressed byte", [100, 256, 200], 8, 8),
        # Test 5: Edge case - no compressed bytes
        ("No compressed bytes", [100, 200, 50, 150], 16, 16),
        # Test 6: Edge case - all compressed bytes
        ("All compressed bytes", [256, 257, 258, 259, 260, 261], 8, 8),
        # Test 7: Mixed compression ratios
        ("24-bit to 12-bit packing (2:1 ratio)", random_bytes_generator(100, 24), 12, 24),
    ]

    test_results = []
    for name, data, bits_per_compressed, compression_bit_threshold in cases:
        test_results.append(run_test_case(
            name,
            data,
            bits_per_compressed=bits_per_compressed,
            compression_bit_threshold=compression_bit_threshold,
        ))
        print()

    # Summary
    passed = sum(test_results)
    total = len(test_results)
    print("=" * 60)
    print(f"📊 TEST SUMMARY: {passed}/{total} tests passed")
    print("=" * 60)

    if passed == total:
        print("🎉 All tests passed! The implementation is working correctly.")
    else:
        print("⚠️  Some tests failed. Please review the implementation.")

    return passed, total

def test_real_data():
    """Round-trip pack/unpack on the first records of a real compressed output file.

    Reads ``output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl`` line by
    line and decodes the base64 field named by ``key`` via
    ``packed_bytes_to_pseudo``. Each decoded record goes through
    ``run_test_case`` with 10-bit packing. At most 6 records are processed
    (indices 0-5).

    The compressed-value bit width comes from the ``ow<N>`` component of
    ``key`` (e.g. ``ow20`` -> 20).
    """
    print("=" * 40)
    print("🔧 REAL DATA TESTS")
    print("=" * 40)

    key = "m1_ac_ow20_escapefb-False_iterative-True"

    # Extract bit_threshold from the "ow<N>" component of the key. It does not
    # depend on the record, so parse it once instead of per line.
    bit_threshold = None
    for k in key.split("_"):
        if k.startswith("ow"):
            bit_threshold = int(k[len("ow"):])
            break
    assert bit_threshold is not None, f"no 'ow<N>' component in key {key!r}"

    with open("output_compress/m1.chunk.0_out_0_out_0_writer_0.jsonl", "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            data = json.loads(line)
            # NOTE: for visualization purposes, we replace values > 256 with byte '_'
            # bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key]))
            # bytes_array = [b if b < 256 else ord('_') for b in bytes_array]
            # bytes_string = bytes(bytes_array).decode("utf-8", errors="replace")
            # print(bytes_string)

            print(f"Bit threshold: {bit_threshold}")

            # Apply the packing function to the original bytes_array (before replacement)
            original_bytes_array = packed_bytes_to_pseudo(base64.b64decode(data[key]))

            run_test_case(
                f"Packing {bit_threshold}-bit values",
                original_bytes_array,
                10,
                bit_threshold
            )

            # Stop after records 0-5.
            if i > 4:
                break


def test_error_conditions():
    """Test error conditions and edge cases.

    Checks two things about ``pack_compressed_spans``. First, a
    ``compression_bit_threshold`` not divisible by ``bits_per_compressed``
    must be rejected with ``AssertionError``. Second, empty input must pack
    to an empty list. Results are printed, not returned.
    """
    print("\n🔧 ERROR CONDITION TESTS")
    print("=" * 40)

    # Test invalid bit alignment
    try:
        pack_compressed_spans([256, 257], 10, 15)  # 15 % 10 != 0
    except AssertionError:
        print("✅ Correctly caught invalid bit alignment")
    else:
        print("❌ Should have failed on invalid bit alignment")

    # Test empty data; the marker must reflect the actual outcome.
    empty_ok = pack_compressed_spans([], 16, 16) == []
    print(f"{'✅' if empty_ok else '❌'} Empty data handling: {empty_ok}")

    print()


if __name__ == "__main__":
    # Run comprehensive tests
    passed, total = test_packing_comprehensive()

    # Final result
    if passed == total:
        print("🏆 ALL TESTS COMPLETED SUCCESSFULLY!")
    else:
        print("💥 SOME TESTS FAILED!")

    # The real-data test reads a generated artifact that may be absent in a
    # given checkout; skip it instead of aborting before the remaining tests.
    try:
        test_real_data()
    except FileNotFoundError as e:
        print(f"⚠️  Skipping real data tests: {e}")

    # Run error condition tests
    test_error_conditions()