File size: 5,054 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions."""
import numpy as np
import torch

def bits_to_bytes_padding_to_threshold(bits: str, bit_threshold: int) -> tuple[bytes, int]:
  """Returns the bytes representation of a bitstream, padded to a threshold.

  The bitstream is right-padded with zero bits until its length is at least
  `bit_threshold` and a multiple of 8, so it packs into whole bytes.

  Args:
    bits: A string of '0'/'1' characters.
    bit_threshold: Minimum number of bits the packed output must represent.

  Returns:
    A tuple of the packed bytes and the number of appended zero bits.
  """
  # Pad to at least `bit_threshold` bits, rounded up to a whole number of
  # bytes. Taking max() with len(bits) also handles bitstreams that already
  # exceed the threshold; without it, ljust would be a no-op and the final
  # chunk below would be shorter than 8 bits, packing to a wrong byte value.
  target_bits = max(len(bits), bit_threshold)
  padded_bits = bits.ljust((target_bits + 7) // 8 * 8, '0')
  num_padded_bits = len(padded_bits) - len(bits)

  # Split the string into 8-bit chunks.
  chunks = [padded_bits[i : i + 8] for i in range(0, len(padded_bits), 8)]

  # Convert each chunk to an integer and then to a byte.
  bytes_data = bytes([int(chunk, base=2) for chunk in chunks])

  return bytes_data, num_padded_bits


def bits_to_bytes(bits: str) -> tuple[bytes, int]:
  """Packs a bitstream string into bytes, zero-padding to a byte boundary.

  Args:
    bits: A string of '0'/'1' characters.

  Returns:
    A tuple of the packed bytes and the number of appended zero bits.
  """
  # Number of trailing zero bits needed to reach a multiple of 8.
  num_padded_bits = -len(bits) % 8
  padded_bits = bits + '0' * num_padded_bits

  # Pack each consecutive group of 8 bits into a single byte value.
  bytes_data = bytes(
      int(padded_bits[start : start + 8], base=2)
      for start in range(0, len(padded_bits), 8)
  )
  return bytes_data, num_padded_bits


def bytes_to_bits(data: bytes, num_padded_bits: int = 0) -> str:
  """Returns the bitstream of bytes data, dropping any trailing padded bits.

  Args:
    data: The bytes to unpack into a '0'/'1' string.
    num_padded_bits: How many trailing bits were padding and must be removed.
  """
  bitstream = ''.join(format(byte, '08b') for byte in data)
  # Guard against num_padded_bits == 0: slicing with [:-0] would empty the
  # string, so only trim when there is actual padding.
  if num_padded_bits:
    bitstream = bitstream[:-num_padded_bits]
  return bitstream


def right_shift_bytes_by_one(data: bytes) -> tuple[bytes, int]:
  """Returns right-shifted bytes, i.e., divided by 2, and the number of bytes.

  Our language models were trained on ASCII data, and not every byte value
  decodes to ASCII. Simply masking the most significant bit (MSB) to 0 keeps
  the data decodable but, for structured data types such as images, destroys
  the structure of the byte values. Halving every byte (a right shift by one)
  keeps relative structure intact while guaranteeing all values are < 128.

  Args:
    data: The bytes to be shifted.
  """
  shifted = bytearray(len(data))
  for index, byte in enumerate(data):
    shifted[index] = byte >> 1
  return bytes(shifted), len(data)


def zero_most_significant_bit_if_not_ascii_decodable(
    data: bytes,
) -> tuple[bytes, int]:
  """Returns ascii-decodable data & the number of zeroed most significant bits.

  Our language models were trained on ASCII data. However, not all bytes can
  be decoded to ASCII, so we clear the most significant bit (MSB) of every
  non-ASCII byte to ensure the data decodes to ASCII.

  Args:
    data: The bytes to be masked.
  """
  masked_data = bytearray()
  masked_bits = 0

  for byte in data:
    # Bytes below 0x80 are already ASCII-decodable (chr(byte).isascii());
    # everything else gets its MSB cleared.
    if byte < 0x80:
      masked_data.append(byte)
    else:
      masked_data.append(byte & 0x7F)
      masked_bits += 1

  return bytes(masked_data), masked_bits


def normalize_pdf_for_arithmetic_coding(pdf: np.ndarray) -> np.ndarray:
  """Normalizes the probabilities for arithmetic coding.

  Arithmetic coding converts the floating-point pdf to integers to avoid
  numerical issues. To that end, all pdf values need to be larger than the
  machine epsilon (to yield different integer values) and the sum of the pdf
  cannot exceed 1 (minus some precision tolerance).

  Args:
    pdf: The probabilities to be normalized.

  Returns:
    The normalized probabilities.
  """
  machine_epsilon = np.finfo(np.float32).eps
  # Normalize so the entries sum to one. np.sum avoids materializing the O(n)
  # intermediate array that np.cumsum(pdf)[-1] would allocate just for its
  # last element.
  pdf = pdf / pdf.sum()
  # Scale down and shift up so every entry is at least machine_epsilon while
  # the total stays strictly below 1 (the sum becomes 1 - n * epsilon).
  pdf = (1 - 2 * pdf.shape[0] * machine_epsilon) * pdf + machine_epsilon
  return pdf

def batched_normalize_pdf_for_arithmetic_coding(pdfs: torch.Tensor):
  """Normalizes a batch of probability vectors for arithmetic coding.

  Each vector along the last dimension is rescaled to sum to one, then
  adjusted so every entry is at least machine epsilon while the total stays
  strictly below 1.

  Args:
    pdfs: The probabilities to be normalized; last dimension is the pdf axis.
  """
  # NOTE: round-tripping through float16 filters out small numerical
  # differences introduced by e.g. batch size or sequence length. A cruder
  # alternative would be bfloat16, but fp16 seems sufficient.
  quantized = pdfs.to(torch.float16).to(torch.float32)
  eps = torch.finfo(torch.float32).eps
  normalized = quantized / quantized.sum(dim=-1, keepdim=True)
  return (1 - 2 * normalized.shape[-1] * eps) * normalized + eps