| |
| |
| |
| |
| |
|
|
| """ |
| Base class for all quantizers. |
| """ |
|
|
| from dataclasses import dataclass, field |
| import typing as tp |
|
|
| import torch |
| from torch import nn |
|
|
|
|
@dataclass
class QuantizedResult:
    """Container for the values a quantizer's forward pass hands back.

    The three tensor fields are required; `penalty` and `metrics` are
    optional and default to "nothing to report".
    """

    # Quantized (or approximately quantized) representation.
    x: torch.Tensor
    # Codes associated with the quantized representation.
    codes: torch.Tensor
    # Bandwidth reported by the quantizer for this result.
    bandwidth: torch.Tensor
    # Extra penalty term for the loss; None when there is none.
    penalty: tp.Optional[torch.Tensor] = None
    # Metrics to feed into logging; default_factory gives each instance
    # its own dict instead of one shared mutable default.
    metrics: dict = field(default_factory=dict)
|
|
|
|
class BaseQuantizer(nn.Module):
    """Interface that every quantizer is expected to implement.

    Nothing is implemented here: each method and property below raises
    `NotImplementedError` and is meant to be overridden by a subclass.
    """

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Quantize `x` and package the outcome as a `QuantizedResult`.

        The result carries the quantized (or approximately quantized)
        representation, the quantized codes, the bandwidth, any penalty
        term for the loss, and a dict of metrics for logging.

        Args:
            x: Input tensor to quantize.
            frame_rate: Frame rate of `x`; required so that the bandwidth
                can be computed properly.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Turn the input tensor `x` into codes.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Map `codes` back to the quantized representation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def total_codebooks(self) -> int:
        """Total number of codebooks.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @property
    def num_codebooks(self) -> int:
        """Number of codebooks currently active.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def set_num_codebooks(self, n: int):
        """Change the number of active codebooks to `n`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization.

    The "codes" are simply the input with an extra dimension inserted at
    index 1, and there is exactly one codebook, which cannot be changed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Return `x` untouched, wrapped in a `QuantizedResult`.

        Args:
            x: Input tensor; its first dimension is treated as the batch
                when computing the per-item bandwidth.
            frame_rate: Frame rate used to scale the bandwidth.

        Returns:
            A `QuantizedResult` whose `x` is the input itself, whose
            `codes` is `x.unsqueeze(1)`, and whose `bandwidth` is a scalar
            tensor with the same dtype and device as `x`. `penalty` and
            `metrics` keep their defaults.
        """
        codes = x.unsqueeze(1)
        # Number of values in one batch item. Taking it from the shape is
        # equivalent to `codes.numel() / len(x)` for a non-empty batch, but
        # does not raise ZeroDivisionError when the batch is empty.
        values_per_item = x.shape[1:].numel()
        # 32 is presumably the bit width of each stored value (float32);
        # dividing by 1000 scales the figure down by a factor of a thousand.
        bandwidth = torch.tensor(values_per_item * 32 * frame_rate / 1000).to(x)
        return QuantizedResult(x, codes, bandwidth)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor into codes.

        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no
        quantization is done; only a dimension of size 1 is inserted at
        index 1.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no
        quantization is done; the dimension at index 1 is squeezed away
        again.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks (always 1 for the dummy quantizer)."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks, always equal to `total_codebooks`."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks.

        Raises:
            AttributeError: Always; the dummy quantizer has a fixed number
                of codebooks.
        """
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
|
|