# Byte-lingua-code / offline_utils.py
import base64
import math
import json
from typing import List, Tuple, Dict, Any, Iterator
import torch
from torch.utils.data import Dataset, IterableDataset
def vread(buf: bytes, i: int):
    """Decode one LEB128-style varint from ``buf`` starting at index ``i``.

    Returns (value, next_index), where next_index points past the last byte read.
    """
    shift = val = 0
while True:
b = buf[i]
i += 1
val |= (b & 0x7F) << shift
if b < 0x80:
return val, i
shift += 7
def vwrite(v: int, out: bytearray):
    """Append the LEB128-style varint encoding of ``v`` to ``out``."""
    while True:
        byte = v & 0x7F
        v >>= 7
        # Set the continuation bit only when more payload bytes follow.
        out.append((byte | 0x80) if v else byte)
if not v:
break
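
# Quick round-trip sketch of the varint codec above (illustrative only, not part of the API):
#
#     buf = bytearray()
#     vwrite(300, buf)                      # buf == bytearray(b'\xac\x02')
#     value, next_i = vread(bytes(buf), 0)
#     assert (value, next_i) == (300, 2)
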
def compress_windows_starts_lens(starts, lens):
    """Encode window (start, length) pairs as varint (gap-from-previous-end, length)
    pairs and return the stream as base64 text."""
    buf = bytearray()
cursor = 0
for s, L in zip(starts, lens):
gap = s - cursor
vwrite(gap, buf)
vwrite(L, buf)
cursor = s + L
return base64.b64encode(buf).decode("ascii")
def decompress_windows_starts_lens(b64_stream):
    """Inverse of ``compress_windows_starts_lens``: recover the starts and lens lists."""
    buf = base64.b64decode(b64_stream)
    i = 0
    cursor = 0
    starts, lens = [], []
    while i < len(buf):
        gap, i = vread(buf, i)
        length, i = vread(buf, i)
        start = cursor + gap
        starts.append(start)
        lens.append(length)
        cursor = start + length
    return starts, lens
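
# Round-trip sketch for the window codec (values chosen for illustration):
#
#     enc = compress_windows_starts_lens([2, 10], [3, 4])
#     assert decompress_windows_starts_lens(enc) == ([2, 10], [3, 4])
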
def unpack_windows(
input_bytes: bytes,
b64_stream: str,
) -> List[Tuple[bytes, int]]:
"""
Returns
- byte_windows: list of (bytes, int) tuples, where the int is 0 if the bytes is raw and 1 if the bytes is compressed
"""
buf = base64.b64decode(b64_stream)
i = 0
cursor = 0
byte_windows = []
while i < len(buf):
gap, i = vread(buf, i)
size, i = vread(buf, i)
start = cursor + gap
if gap > 0:
hole = input_bytes[cursor:start]
byte_windows.append((hole, 0))
length = size
end = start + length
win = input_bytes[start:end]
byte_windows.append((win, 1))
cursor = end
if cursor < len(input_bytes):
hole = input_bytes[cursor:]
byte_windows.append((hole, 0))
return byte_windows
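
# Illustrative call, reusing the window codec above (placeholder data):
#
#     data = b"abcdefghij"
#     enc = compress_windows_starts_lens([2, 7], [3, 2])
#     unpack_windows(data, enc)
#     # -> [(b"ab", 0), (b"cde", 1), (b"fg", 0), (b"hi", 1), (b"j", 0)]
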
def pseudo_to_packed_bytes(lst: list[int]) -> bytes:
    """Pack a list of 9-bit values (0..511) into bytes, LSB-first."""
    out = bytearray()
acc = bits = 0
for v in lst:
acc |= (v & 0x1FF) << bits
bits += 9
while bits >= 8:
out.append(acc & 0xFF)
acc >>= 8
bits -= 8
if bits: # flush tail
out.append(acc)
return bytes(out)
def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Inverse of ``pseudo_to_packed_bytes``: unpack bytes back into 9-bit values."""
    out, acc, bits = [], 0, 0
for byte in b:
acc |= byte << bits
bits += 8
while bits >= 9:
out.append(acc & 0x1FF)
acc >>= 9
bits -= 9
return out
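
# Round-trip sketch for the 9-bit packing helpers (illustrative values):
#
#     tokens = [256, 7, 300]                      # each value must fit in 9 bits (0..511)
#     packed = pseudo_to_packed_bytes(tokens)     # 4 bytes for 27 bits of payload
#     assert packed_bytes_to_pseudo(packed) == tokens
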
def pad_batch(batch: List[bytes]):
    """Right-pad a batch of byte/int sequences with zeros and return (padded, lengths)."""
    # ``list(data)`` lets torch.tensor accept both ``bytes`` objects and plain int lists.
    batch_tensors = [torch.tensor(list(data), dtype=torch.int64) for data in batch]
    lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
    # ``padding_side`` requires a recent PyTorch; "right" matches the default behaviour.
    padded_batch = torch.nn.utils.rnn.pad_sequence(
        batch_tensors,
        batch_first=True,
        padding_value=0,
        padding_side="right",
    )
    return padded_batch, lengths
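
# Example sketch:
#
#     padded, lengths = pad_batch([b"abc", b"de"])
#     # padded == tensor([[97, 98, 99], [100, 101, 0]]); lengths == tensor([3, 2])
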
class JsonlShardedDataset(Dataset):
    """Map-style dataset that loads an entire JSONL file and keeps only the
    contiguous slice of lines assigned to this process."""
def __init__(
self,
file_path: str,
current_proc_rank: int = 0,
total_procs: int = 1,
) -> None:
        assert 0 <= current_proc_rank < total_procs, "current_proc_rank must be in [0, total_procs)"
self.current_proc_rank = current_proc_rank
self.total_procs = total_procs
# -- load the whole file once (fast for < few-GB files) -------------
with open(file_path, "r", encoding="utf-8") as f:
full_data: List[Dict[str, Any]] = [json.loads(line) for line in f]
# -- pick the slice that belongs to *this* process ------------------
total = len(full_data)
per_proc = math.ceil(total / total_procs)
start = current_proc_rank * per_proc
end = min(start + per_proc, total)
self.data = full_data[start:end]
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, idx: int) -> Dict[str, Any]:
return self.data[idx]
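
# Example (the path is a placeholder):
#
#     ds = JsonlShardedDataset("records.jsonl", current_proc_rank=1, total_procs=4)
#     len(ds)      # roughly one quarter of the file's lines
#     ds[0]        # first JSON object in this rank's contiguous slice
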
class InterleavedJsonlDataset(IterableDataset):
"""
An iterable-style dataset for reading a large JSONL file using an
interleaving/striding pattern, without yielding state information.
This is designed for multi-process data loading. Each process reads the
entire file but only processes lines that match its rank (offset).
For `N` total processes (world_size), process `r` (rank) will read
lines r, r+N, r+2N, ... (0-indexed).
This method ensures an even distribution of lines across processes.
Args:
file_path (str): Path to the JSONL file.
rank (int): The rank of the current process, used as the offset.
world_size (int): The total number of processes, used as the block_size/stride.
"""
def __init__(
self,
file_path: str,
rank: int,
world_size: int,
) -> None:
super().__init__()
if not (0 <= rank < world_size):
raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}")
self.file_path = file_path
self.offset = rank
self.block_size = world_size
def __iter__(self) -> Iterator[Dict[str, Any]]:
"""
The iterator method that yields the parsed JSON data for the assigned lines.
"""
try:
with open(self.file_path, "r", encoding="utf-8") as f:
# We use a simple line counter to determine which lines to process.
# The line_number is 0-indexed.
for line_number, line in enumerate(f):
# Check if the current line number belongs to this process
if (line_number % self.block_size) == self.offset:
try:
# Yield the parsed JSON object
yield json.loads(line)
except json.JSONDecodeError:
# This line is malformed. We can either raise an error
# or, more robustly, just print a warning and skip it.
print(f"Warning: Rank {self.offset} could not decode JSON on line ~{line_number+1}. Skipping.")
continue
except Exception as e:
print(f"Error in worker {self.offset}: {e}")
raise
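
# Illustrative use (the file path is a placeholder):
#
#     ds = InterleavedJsonlDataset("corpus.jsonl", rank=0, world_size=4)
#     for record in ds:          # yields lines 0, 4, 8, ...
#         ...
#
# Note: with a multi-worker DataLoader, each worker holds a replica of this
# IterableDataset and would re-yield the same rank's lines, so keep
# num_workers at 0/1 or add worker-aware striding.
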
def batched_m1_compress_predict_fn(model):
    """Wrap ``model`` into a batched predict function returning softmax probabilities
    over the 256 byte values."""
    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
if input_tensor.dim() == 1:
input_tensor = input_tensor.unsqueeze(0)
with torch.no_grad():
# get logits
logits = model(input_tensor, **kwargs)
logits = logits[..., :256]
logits = logits.float()
assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
probs = torch.softmax(logits, dim=-1)
return probs
return predict_fn
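
# Sketch of how the factory might be used (``m1_model`` is a placeholder for any
# module whose forward returns per-position logits of shape (batch, seq, vocab)):
#
#     predict_fn = batched_m1_compress_predict_fn(m1_model)
#     probs = predict_fn(torch.tensor([104, 105], dtype=torch.int64))
#     # probs.shape == (1, 2, 256); each distribution sums to 1
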
def find_next_batch_range(all_windows, start_idx, max_m1_batch_size, get_batch_size_for_length_fn):
    """Return the half-open index range [start_idx, end_idx) of consecutive windows that
    all map to the same batch size as the window at ``start_idx``.

    The binary search below assumes the "same batch size as the first window" property
    holds for a contiguous prefix of the candidate range (e.g. windows sorted by length).
    """
    M = len(all_windows)
if start_idx >= M:
return start_idx, start_idx
first_window_len = len(all_windows[start_idx])
base_batch_size = get_batch_size_for_length_fn(first_window_len, max_m1_batch_size)
low = start_idx
high = min(start_idx + base_batch_size, M)
high_batch_size = get_batch_size_for_length_fn(len(all_windows[high - 1]), max_m1_batch_size)
if high_batch_size == base_batch_size:
return start_idx, high
search_low = low
search_high = high
while search_low < search_high:
mid = search_low + (search_high - search_low) // 2
mid_window_len = len(all_windows[mid])
if get_batch_size_for_length_fn(mid_window_len, max_m1_batch_size) == base_batch_size:
# This window is valid. The partition point must be to the right of it.
# So, we continue searching in the range [mid + 1, high).
search_low = mid + 1
else:
# This window is NOT valid. It might be the partition point itself,
# or the point is to its left.
# So, we continue searching in the range [low, mid).
search_high = mid
end_idx = search_low
if end_idx == start_idx:
return start_idx, start_idx + 1
else:
return start_idx, end_idx
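
# Example sketch (``bs_for_len`` is a made-up policy: smaller batches for longer windows):
#
#     def bs_for_len(length, max_bs):
#         return max_bs if length <= 128 else max_bs // 2
#
#     windows = [b"\x00" * 64] * 3 + [b"\x00" * 200] * 2   # sorted by length
#     find_next_batch_range(windows, 0, 8, bs_for_len)      # -> (0, 3)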