| """Utility functions.""" |
| import numpy as np |
| import torch |
|
|
def bits_to_bytes_padding_to_threshold(
    bits: str,
    bit_threshold: int,
) -> tuple[bytes, int]:
  """Returns the bytes representation of bitstream and number of padded bits."""
  # Pad with zeros up to the bit threshold (or the length of the bitstream,
  # whichever is larger), rounded up to a whole number of bytes.
  target_num_bits = (max(len(bits), bit_threshold) + 7) // 8 * 8
  padded_bits = bits.ljust(target_num_bits, '0')
  num_padded_bits = len(padded_bits) - len(bits)

  # Split the padded bitstream into 8-bit chunks and convert each to a byte.
  chunks = [padded_bits[i : i + 8] for i in range(0, len(padded_bits), 8)]
  bytes_data = bytes([int(chunk, base=2) for chunk in chunks])

  return bytes_data, num_padded_bits
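
# Illustrative example (not from the original module): padding a 3-bit stream
# to a 16-bit threshold yields two bytes and 13 padded bits.
#   bits_to_bytes_padding_to_threshold('101', 16) == (b'\xa0\x00', 13)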
|
|
|
|
def bits_to_bytes(bits: str) -> tuple[bytes, int]:
  """Returns the bytes representation of bitstream and number of padded bits."""
  # Pad with zeros so the length of the bitstream is a multiple of 8.
  padded_bits = bits.ljust((len(bits) + 7) // 8 * 8, '0')
  num_padded_bits = len(padded_bits) - len(bits)

  # Split the padded bitstream into 8-bit chunks and convert each to a byte.
  chunks = [padded_bits[i : i + 8] for i in range(0, len(padded_bits), 8)]
  bytes_data = bytes([int(chunk, base=2) for chunk in chunks])

  return bytes_data, num_padded_bits
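
# Illustrative example (not from the original module): a 5-bit stream is padded
# with 3 zeros to fill a single byte.
#   bits_to_bytes('10100') == (b'\xa0', 3)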
|
|
|
|
def bytes_to_bits(data: bytes, num_padded_bits: int = 0) -> str:
  """Returns the bitstream of bytes data accounting for padded bits."""
  # The padded bits are stripped from the end; `[:-0]` would return an empty
  # string, so the zero-padding case is handled separately.
  if num_padded_bits == 0:
    return ''.join([bin(byte)[2:].zfill(8) for byte in data])
  else:
    return ''.join([bin(byte)[2:].zfill(8) for byte in data])[:-num_padded_bits]
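
# Illustrative example (not from the original module): bytes_to_bits inverts
# bits_to_bytes when given the number of padded bits.
#   bytes_to_bits(b'\xa0', num_padded_bits=3) == '10100'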
|
|
|
|
def right_shift_bytes_by_one(data: bytes) -> tuple[bytes, int]:
  """Returns right-shifted bytes, i.e., divided by 2, and the number of bytes.

  Our language models were trained on ASCII data. However, not all bytes can be
  decoded to ASCII, so we set the most significant bit (MSB) to 0, to ensure
  that we can decode the data to ASCII.

  However, for certain data types (e.g., images), masking the MSB and leaving
  the rest of the byte unchanged will destroy the structure of the data. Thus,
  we instead divide the number by two (i.e., we shift the bits to the right by
  one).

  Args:
    data: The bytes to be shifted.
  """
  return bytes([byte >> 1 for byte in data]), len(data)
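
# Illustrative example (not from the original module): every byte is halved, so
# the result is guaranteed to be ASCII-decodable.
#   right_shift_bytes_by_one(b'\xff\x02') == (b'\x7f\x01', 2)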
|
|
|
|
def zero_most_significant_bit_if_not_ascii_decodable(
    data: bytes,
) -> tuple[bytes, int]:
  """Returns ascii-decodable data & the number of zeroed most significant bits.

  Our language models were trained on ASCII data. However, not all bytes can be
  decoded to ASCII, so we set the most significant bit (MSB) to 0, to ensure
  that we can decode the data to ASCII.

  Args:
    data: The bytes to be masked.
  """
  masked_bits = 0
  masked_data = list()

  for byte in data:
    if chr(byte).isascii():
      masked_data.append(byte)
    else:
      # Zero the MSB (mask with 0x7F) so the byte falls into the ASCII range.
      masked_bits += 1
      masked_data.append(byte & 0x7F)

  return bytes(masked_data), masked_bits
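
# Illustrative example (not from the original module): the ASCII byte 'A' is
# kept as-is, while 0xC8 has its MSB zeroed and becomes 0x48 ('H').
#   zero_most_significant_bit_if_not_ascii_decodable(b'A\xc8') == (b'AH', 1)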
|
|
|
|
def normalize_pdf_for_arithmetic_coding(pdf: np.ndarray) -> np.ndarray:
  """Normalizes the probabilities for arithmetic coding.

  Arithmetic coding converts the floating-point pdf to integers to avoid
  numerical issues. To that end, all pdf values need to be larger than the
  machine epsilon (to yield different integer values) and the sum of the pdf
  cannot exceed 1 (minus some precision tolerance).

  Args:
    pdf: The probabilities to be normalized.

  Returns:
    The normalized probabilities.
  """
  machine_epsilon = np.finfo(np.float32).eps
  # Normalize the pdf so that it sums to 1.
  pdf = pdf / np.cumsum(pdf)[-1]
  # Shrink the pdf slightly and add epsilon so that every probability is at
  # least machine epsilon and the total stays below 1.
  pdf = (1 - 2 * pdf.shape[0] * machine_epsilon) * pdf + machine_epsilon
  return pdf
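
# Illustrative example (not from the original module): an entry that is exactly
# zero becomes roughly machine epsilon (~1.2e-7), and the total drops just
# below 1.
#   normalize_pdf_for_arithmetic_coding(np.array([0.5, 0.5, 0.0], np.float32))
#   # ~ array([0.5, 0.5, 1.2e-07], dtype=float32)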
|
|
def batched_normalize_pdf_for_arithmetic_coding(
    pdfs: torch.Tensor,
) -> torch.Tensor:
  """Normalizes the probabilities for arithmetic coding.

  Batched variant of `normalize_pdf_for_arithmetic_coding` that operates on the
  last dimension of a tensor of probabilities.

  Args:
    pdfs: The probabilities to be normalized.

  Returns:
    The normalized probabilities.
  """
  # Round-trip through float16 to coarsen the probabilities to float16
  # precision before renormalizing in float32.
  pdfs = pdfs.to(torch.float16).to(torch.float32)
  machine_epsilon = torch.finfo(torch.float32).eps
  # Normalize so each distribution sums to 1, then shrink and add epsilon so
  # every probability is at least machine epsilon and each sum stays below 1.
  pdfs = pdfs / pdfs.sum(dim=-1, keepdim=True)
  pdfs = (1 - 2 * pdfs.shape[-1] * machine_epsilon) * pdfs + machine_epsilon
  return pdfs
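
# Illustrative example (not from the original module): a possible usage
# pattern, normalizing a batch of softmax outputs before arithmetic coding.
#   logits = torch.randn(2, 4, 256)
#   pdfs = batched_normalize_pdf_for_arithmetic_coding(logits.softmax(dim=-1))
#   # Every entry is now at least machine epsilon and each row sums to < 1.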
|
|