| import base64 |
| import math |
| import json |
| from typing import List, Tuple, Dict, Any, Iterator |
| import torch |
| from torch.utils.data import Dataset, IterableDataset |
|
|
|
|
def vread(buf: bytes, i: int) -> Tuple[int, int]:
    """Read one LEB128 varint from ``buf`` starting at index ``i``.

    Returns the decoded value and the index of the first byte after it.
    """
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:  # high bit clear: last byte of this varint
            return val, i
        shift += 7
|
|
def vwrite(v: int, out: bytearray) -> None:
    """Append the non-negative integer ``v`` to ``out`` as a LEB128 varint."""
    while True:
        byte = v & 0x7F
        v >>= 7
        out.append(byte | 0x80 if v else byte)  # set the high bit if more bytes follow
        if not v:
            break
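

# Minimal round-trip sketch for the varint codec above; the values are
# illustrative and assume non-negative integers, which is all vwrite supports.
def _demo_varint_roundtrip() -> None:
    buf = bytearray()
    for v in (0, 1, 127, 128, 300):
        vwrite(v, buf)
    i, decoded = 0, []
    while i < len(buf):
        v, i = vread(bytes(buf), i)
        decoded.append(v)
    assert decoded == [0, 1, 127, 128, 300]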
|
|
def compress_windows_starts_lens(starts, lens):
    """Delta-encode (start, length) window pairs as varints and base64 the result.

    Windows must be sorted by start and non-overlapping, so that each gap
    (distance from the end of the previous window) is non-negative.
    """
    buf = bytearray()
    cursor = 0
    for s, L in zip(starts, lens):
        gap = s - cursor
        vwrite(gap, buf)
        vwrite(L, buf)
        cursor = s + L
    return base64.b64encode(buf).decode("ascii")
|
|
def decompress_windows_starts_lens(b64_stream):
    """Inverse of compress_windows_starts_lens: recover the starts and lens lists."""
    buf = base64.b64decode(b64_stream)
    i = 0
    cursor = 0
    starts, lens = [], []
    while i < len(buf):
        gap, i = vread(buf, i)
        length, i = vread(buf, i)
        start = cursor + gap
        starts.append(start)
        lens.append(length)
        cursor = start + length
    return starts, lens
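

# Quick round-trip sketch for the window codec; the starts/lens are made up
# but satisfy the sorted, non-overlapping layout the gap encoding requires.
def _demo_window_codec_roundtrip() -> None:
    starts, lens = [2, 10], [3, 4]
    b64 = compress_windows_starts_lens(starts, lens)
    assert decompress_windows_starts_lens(b64) == (starts, lens)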
|
|
def unpack_windows(
    input_bytes: bytes,
    b64_stream: str,
) -> List[Tuple[bytes, int]]:
    """
    Split ``input_bytes`` into alternating raw and window segments.

    Returns
    - byte_windows: list of (segment_bytes, flag) tuples, where flag is 0 for
      raw "hole" bytes outside any window and 1 for bytes inside a window.
    """
| buf = base64.b64decode(b64_stream) |
| i = 0 |
| cursor = 0 |
| byte_windows = [] |
|
|
    while i < len(buf):
        gap, i = vread(buf, i)
        size, i = vread(buf, i)
        start = cursor + gap
        if gap > 0:
            # bytes between the previous window and this one form a raw "hole"
            hole = input_bytes[cursor:start]
            byte_windows.append((hole, 0))
        end = start + size
        win = input_bytes[start:end]
        byte_windows.append((win, 1))
        cursor = end
|
|
    if cursor < len(input_bytes):
        # any bytes after the last window are also raw
        hole = input_bytes[cursor:]
        byte_windows.append((hole, 0))
|
|
| return byte_windows |
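

# Illustrative use of unpack_windows: with windows [2:5) and [10:14) over a
# 14-byte input, the raw holes and flagged windows come back interleaved.
def _demo_unpack_windows() -> None:
    data = b"abcdefghijklmn"
    b64 = compress_windows_starts_lens([2, 10], [3, 4])
    assert unpack_windows(data, b64) == [
        (b"ab", 0), (b"cde", 1), (b"fghij", 0), (b"klmn", 1)
    ]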
|
|
|
|
def pseudo_to_packed_bytes(lst: list[int]) -> bytes:
    """Pack a list of 9-bit "pseudo-byte" values (0..511) into bytes, LSB first."""
    out = bytearray()
    acc = bits = 0
    for v in lst:
        acc |= (v & 0x1FF) << bits
        bits += 9
        while bits >= 8:
            out.append(acc & 0xFF)
            acc >>= 8
            bits -= 8
    if bits:  # flush any remaining partial byte (zero-padded in its high bits)
        out.append(acc)
    return bytes(out)
|
|
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Inverse of pseudo_to_packed_bytes: unpack bytes back into 9-bit values."""
    out, acc, bits = [], 0, 0
    for byte in b:
        acc |= byte << bits
        bits += 8
        while bits >= 9:
            out.append(acc & 0x1FF)
            acc >>= 9
            bits -= 9
    return out
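

# Round-trip sketch for the 9-bit packer; values outside 0..511 would be
# truncated by the & 0x1FF mask, so the example stays in range.
def _demo_pseudo_packing_roundtrip() -> None:
    values = [0, 255, 256, 300, 511]
    assert packed_bytes_to_pseudo(pseudo_to_packed_bytes(values)) == values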
|
|
def pad_batch(batch: List[bytes]):
    """Right-pad a batch of byte strings into one (B, T) int64 tensor.

    Returns the padded tensor and a (B,) tensor of the original lengths.
    """
    # list(data) so both bytes objects and plain int sequences are accepted.
    batch_tensors = [torch.tensor(list(data), dtype=torch.int64) for data in batch]
    lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
    padded_batch = torch.nn.utils.rnn.pad_sequence(
        batch_tensors,
        batch_first=True,
        padding_value=0,
        padding_side="right",  # "right" is the default; the kwarg needs a recent PyTorch
    )
    return padded_batch, lengths
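

# Usage sketch for pad_batch: two byte strings of different lengths come back
# as one right-padded (2, 4) tensor plus the original lengths.
def _demo_pad_batch() -> None:
    padded, lengths = pad_batch([b"ab", b"wxyz"])
    assert padded.shape == (2, 4)
    assert lengths.tolist() == [2, 4]
    assert padded[0].tolist() == [97, 98, 0, 0]  # "ab" is 0x61 0x62, zero-padded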
|
|
class JsonlShardedDataset(Dataset):
    """Map-style dataset that eagerly loads a JSONL file and keeps one contiguous shard.

    The file is split into ``total_procs`` contiguous chunks of up to
    ``ceil(total / total_procs)`` lines each; rank ``r`` keeps chunk ``r``.
    When the line count is not divisible, the last ranks may receive fewer
    lines (possibly none).
    """

    def __init__(
        self,
        file_path: str,
        current_proc_rank: int = 0,
        total_procs: int = 1,
    ) -> None:
        assert 0 <= current_proc_rank < total_procs, "current_proc_rank must be in [0, total_procs)"
        self.current_proc_rank = current_proc_rank
        self.total_procs = total_procs

        with open(file_path, "r", encoding="utf-8") as f:
            full_data: List[Dict[str, Any]] = [json.loads(line) for line in f]

        total = len(full_data)
        per_proc = math.ceil(total / total_procs)
        start = current_proc_rank * per_proc
        end = min(start + per_proc, total)
        self.data = full_data[start:end]
|
|
| def __len__(self) -> int: |
| return len(self.data) |
|
|
| def __getitem__(self, idx: int) -> Dict[str, Any]: |
| return self.data[idx] |
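

# Usage sketch for JsonlShardedDataset (the file contents here are made up):
# each rank loads the whole file but keeps only its contiguous slice.
def _demo_sharded_dataset() -> None:
    import tempfile
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        for k in range(10):
            f.write(json.dumps({"id": k}) + "\n")
    ds = JsonlShardedDataset(f.name, current_proc_rank=1, total_procs=4)
    # ceil(10 / 4) == 3 lines per rank, so rank 1 holds ids 3, 4, 5
    assert [ds[i]["id"] for i in range(len(ds))] == [3, 4, 5]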
|
|
class InterleavedJsonlDataset(IterableDataset):
    """
    An iterable-style dataset that reads a large JSONL file with a strided
    (interleaved) access pattern, without keeping any iteration state.

    This is designed for multi-process data loading. Each process scans the
    entire file but only parses the lines that match its rank (offset).
    For `N` total processes (world_size), process `r` (rank) reads
    lines r, r+N, r+2N, ... (0-indexed). This spreads the lines evenly
    across processes.

    Note: a PyTorch DataLoader with `num_workers > 0` gives every worker a
    replica of an IterableDataset, so lines would be duplicated across
    workers unless the offset/stride also account for
    `torch.utils.data.get_worker_info()`.

    Args:
        file_path (str): Path to the JSONL file.
        rank (int): The rank of the current process, used as the offset.
        world_size (int): The total number of processes, used as the stride.
    """
| def __init__( |
| self, |
| file_path: str, |
| rank: int, |
| world_size: int, |
| ) -> None: |
| super().__init__() |
| |
| if not (0 <= rank < world_size): |
| raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}") |
|
|
| self.file_path = file_path |
| self.offset = rank |
| self.block_size = world_size |
|
|
    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield the parsed JSON object for every line assigned to this rank."""
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                for line_number, line in enumerate(f):
                    # Only parse the lines that belong to this rank.
                    if (line_number % self.block_size) == self.offset:
                        try:
                            yield json.loads(line)
                        except json.JSONDecodeError:
                            # Malformed lines are skipped rather than aborting the epoch.
                            print(
                                f"Warning: Rank {self.offset} could not decode JSON "
                                f"on line {line_number + 1}. Skipping."
                            )
                            continue
        except Exception as e:
            print(f"Error in worker {self.offset}: {e}")
            raise
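

# Usage sketch for the interleaved reader (file contents are made up): with a
# stride of 4, rank 1 yields lines 1, 5, 9 without materializing the file.
def _demo_interleaved_dataset() -> None:
    import tempfile
    with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
        for k in range(10):
            f.write(json.dumps({"id": k}) + "\n")
    ds = InterleavedJsonlDataset(f.name, rank=1, world_size=4)
    assert [ex["id"] for ex in ds] == [1, 5, 9]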
|
|
|
|
def batched_m1_compress_predict_fn(model):
    """Wrap ``model`` into a predict_fn mapping token ids to next-byte probabilities."""
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)  # promote to a batch of one
        with torch.no_grad():
            logits = model(input_tensor, **kwargs)
            logits = logits[..., :256]  # keep only the 256 byte-valued vocabulary entries
            logits = logits.float()
            assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
            probs = torch.softmax(logits, dim=-1)
            return probs

    return predict_fn
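

# Sketch of the predict_fn contract using a stand-in model; the nn.Embedding
# here is purely illustrative, any callable returning (B, T, V >= 256) logits fits.
def _demo_predict_fn() -> None:
    dummy_model = torch.nn.Embedding(256, 300)  # produces 300-wide "logits"
    predict = batched_m1_compress_predict_fn(dummy_model)
    probs = predict(torch.tensor([1, 2, 3], dtype=torch.int64))
    assert probs.shape == (1, 3, 256)
    assert torch.allclose(probs.sum(dim=-1), torch.ones(1, 3))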
|
|
|
|
def find_next_batch_range(all_windows, start_idx, max_m1_batch_size, get_batch_size_for_length_fn):
    """Find the next [start_idx, end_idx) slice of windows that share one batch size.

    Assumes ``all_windows`` is sorted by length, so the per-length batch size
    returned by ``get_batch_size_for_length_fn`` is monotonic over the list.
    """
    M = len(all_windows)
    if start_idx >= M:
        return start_idx, start_idx

    first_window_len = len(all_windows[start_idx])
    base_batch_size = get_batch_size_for_length_fn(first_window_len, max_m1_batch_size)

    # Fast path: if the last candidate window still maps to the same batch
    # size, the whole range can be batched together.
    low = start_idx
    high = min(start_idx + base_batch_size, M)
    high_batch_size = get_batch_size_for_length_fn(len(all_windows[high - 1]), max_m1_batch_size)
    if high_batch_size == base_batch_size:
        return start_idx, high

    # Otherwise binary-search for the first window whose batch size differs;
    # everything before it shares base_batch_size.
    search_low = low
    search_high = high
    while search_low < search_high:
        mid = search_low + (search_high - search_low) // 2
        mid_window_len = len(all_windows[mid])
        if get_batch_size_for_length_fn(mid_window_len, max_m1_batch_size) == base_batch_size:
            # mid still fits the current bucket: the boundary is to the right
            search_low = mid + 1
        else:
            # mid is in a smaller bucket: the boundary is at or left of mid
            search_high = mid
    end_idx = search_low
    if end_idx == start_idx:
        # Defensive: always make progress by emitting at least one window.
        return start_idx, start_idx + 1
    return start_idx, end_idx
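

# Sketch of how the search behaves with a toy bucketing rule (halve the batch
# size once windows exceed 4 bytes); all_windows must be sorted by length.
def _demo_find_next_batch_range() -> None:
    def bucket(length: int, max_bs: int) -> int:
        return max_bs if length <= 4 else max_bs // 2

    windows = [b"aa", b"bbb", b"cccc", b"ddddd", b"eeeeee"]
    # The first three windows share batch size 8; the boundary sits at index 3.
    assert find_next_batch_range(windows, 0, 8, bucket) == (0, 3)
    assert find_next_batch_range(windows, 3, 8, bucket) == (3, 5)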
|
|