import os

import torch
import torch.nn.functional as F
from einops import rearrange
from fla.layers.simple_gla import SimpleGatedLinearAttention
from fla.models.utils import Cache
from torch import nn

from tts.layers.attention import CrossAttention
from tts.layers.ffn import SwiGLU

from .cache_utils import FLACache
from .config import SimpleGLADecoderConfig
from .registry import register_decoder
from .shortconv import ShortConvBlock


if "GRAD_CKPT" in os.environ:

    def maybe_grad_ckpt(f):
        def grad_ckpt_f(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(
                f, *args, **kwargs, use_reentrant=False
            )

        return grad_ckpt_f

else:

    def maybe_grad_ckpt(f):
        return f
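
# Usage note: setting the GRAD_CKPT environment variable (to any value) makes
# maybe_grad_ckpt wrap each decoder-layer call in non-reentrant activation
# checkpointing, trading extra recomputation for lower activation memory.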


class SimpleGLABlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int,
        expand_k: float,
        expand_v: float,
        use_short_conv: bool,
        ffn_expansion_factor: int,
    ):
        super().__init__()
        # NOTE: expand_k, expand_v and use_short_conv are accepted for config
        # compatibility but are not forwarded to SimpleGatedLinearAttention,
        # which therefore falls back to its own defaults for those settings.
        self.tmix = SimpleGatedLinearAttention(
            hidden_size=dim,
            num_heads=num_heads,
            layer_idx=layer_idx,
        )
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
    ):
        x = (
            self.tmix(
                self.norm1(x),
                past_key_values=cache,
                use_cache=cache is not None,
            )[0]
            + x
        )
        x = self.cmix(self.norm2(x)) + x
        return x
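
# SimpleGLABlock is a pre-norm residual block: gated linear attention (time mixing)
# followed by a SwiGLU feed-forward (channel mixing). Minimal usage sketch
# (dimensions are illustrative assumptions, not taken from any config):
#
#   block = SimpleGLABlock(dim=512, num_heads=8, layer_idx=0, expand_k=1.0,
#                          expand_v=1.0, use_short_conv=False, ffn_expansion_factor=4)
#   y = block(torch.randn(2, 128, 512))  # (batch, seq_len, dim) -> same shape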


class DecoderBlockWithOptionalCrossAttention(nn.Module):
    def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
        super().__init__()
        self.decoder_block = decoder_block
        self.crossatt = crossatt

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor | None = None,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
        selfatt_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        x = self.decoder_block(
            x,
            freqs=freqs,
            cache=cache,
        )
        if isinstance(crossatt_mask, list):
            crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]
        if self.crossatt is not None:
            x = x + self.crossatt(
                x,
                k=encoder_output,
                text_freqs=text_freqs,
                mask=crossatt_mask,
                cache=cache,
            )
        return x
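
# Note: when `crossatt_mask` is a list it is interpreted as one mask per layer and
# indexed with decoder_block.tmix.layer_idx, so list-form masks require the wrapped
# block to expose a `tmix` module with a `layer_idx` attribute (as SimpleGLABlock does).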


class SimpleGLADecoder(nn.Module):
    config = SimpleGLADecoderConfig

    def __init__(self, cfg: SimpleGLADecoderConfig):
        super().__init__()
        assert cfg.dim % cfg.num_heads == 0, "num_heads should divide dim"
        assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
            "at most one specialized cross-attention"
        )
        self.head_dim = cfg.dim // cfg.num_heads
        self.num_heads = cfg.num_heads

        def simple_gla_block(i):
            conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
            if i in conv_layers:
                return ShortConvBlock(
                    dim=cfg.dim,
                    kernel_size=4,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                    layer_idx=i,
                    use_fast_conv1d=True,
                )
            else:
                return SimpleGLABlock(
                    dim=cfg.dim,
                    num_heads=cfg.num_heads,
                    layer_idx=i,
                    expand_k=cfg.expand_k,
                    expand_v=cfg.expand_v,
                    use_short_conv=cfg.use_short_conv,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                )

        def crossatt_block(i):
            if i in cfg.crossatt_layer_idx:
                return CrossAttention(
                    dim=cfg.dim,
                    num_heads=cfg.crossatt_num_heads,
                    dropout=cfg.crossatt_dropout,
                    layer_idx=i,
                )
            else:
                return None

        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlockWithOptionalCrossAttention(
                    simple_gla_block(i),
                    crossatt_block(i),
                )
                for i in range(cfg.num_layers)
            ]
        )
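
    # Layer layout: every entry of decoder_layers wraps either a SimpleGLABlock or, at
    # the indices listed in cfg.conv_layers, a ShortConvBlock; cross-attention over the
    # encoder output is attached only at the indices in cfg.crossatt_layer_idx.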

    def forward(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
        text_ids: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        x = decoder_input
        text_freqs = None
        for layer in self.decoder_layers:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x
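
    # Note: forward() passes text_freqs=None to every layer and accepts `text_ids`
    # without using it here; per-layer cross-attention masks can be supplied as a
    # list (one entry per layer) via `crossatt_mask`.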

    def init_cache(self, max_seq_len, device):
        return FLACache(num_states=len(self.decoder_layers) + 1)

    def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
        return tuple(
            nn.Parameter(
                torch.randn(
                    batch_size,
                    self.num_heads,
                    self.head_dim,
                    self.head_dim,
                    device=device,
                )
                * scale
            )
            for _ in range(len(self.decoder_layers))
        )
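
    # init_initial_state returns one learnable (batch_size, num_heads, head_dim, head_dim)
    # recurrent-state tensor per decoder layer, initialized as scaled Gaussian noise.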

    def init_initial_state_lora(
        self, lora: int = 1, batch_size: int = 1, scale: float = 1e-2, device: str = "cpu"
    ):
        return tuple(
            (
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        self.head_dim,
                        lora,
                        device=device,
                    )
                    * scale
                ),
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        lora,
                        self.head_dim,
                        device=device,
                    )
                    * scale
                ),
            )
            for _ in range(len(self.decoder_layers))
        )
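
    # The LoRA variant factorizes each per-layer initial state into two low-rank factors
    # of shape (batch_size, num_heads, head_dim, lora) and (batch_size, num_heads, lora,
    # head_dim); presumably they are multiplied downstream to form a rank-`lora`
    # head_dim x head_dim state with far fewer learnable parameters.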

    def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
        assert self.decoder_layers[layer_idx].crossatt is not None
        x = audio_inputs
        # Run the first layer_idx - 1 layers without encoder context, then project the
        # result with the cross-attention query projection of layer `layer_idx`.
        for layer in self.decoder_layers[: layer_idx - 1]:
            x = layer(x, None)
        return self.decoder_layers[layer_idx].crossatt._query(x)

    def forward_first_n_layers(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        n_first_layers: int,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        x = decoder_input
        # Guard with getattr: text_freqs_embd is optional and is not created in
        # __init__ above, so fall back to no text positional embedding when absent.
        text_freqs_embd = getattr(self, "text_freqs_embd", None)
        if text_freqs_embd is not None:
            text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
            text_freqs = text_freqs_embd(text_freqs)
        else:
            text_freqs = None
        for layer in self.decoder_layers[:n_first_layers]:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x
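
    # forward_first_n_layers runs only the first `n_first_layers` layers, e.g. to
    # inspect intermediate activations (see _get_query above).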

    def prefill(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        return self(
            encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask
        )

    def decode_one(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        cache: Cache,
        text_freqs: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
    ):
        x = decoder_input
        for layer in self.decoder_layers:
            x = layer(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x
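
    # Streaming sketch (shapes and the projection head are illustrative assumptions):
    # prefill the cache with the full prompt once, then feed one frame at a time
    # through decode_one, reusing the same FLACache.
    #
    #   cache = decoder.init_cache(max_seq_len=1024, device="cuda")
    #   x = decoder.prefill(encoder_output, decoder_input, cache=cache)
    #   next_frame = some_projection(x[:, -1:])   # hypothetical output head
    #   x = decoder.decode_one(encoder_output, next_frame, cache=cache)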