import os

import torch
from fla.layers.simple_gla import SimpleGatedLinearAttention
from fla.models.utils import Cache
from torch import nn

from tts.layers.attention import CrossAttention
from tts.layers.ffn import SwiGLU

from .cache_utils import FLACache
from .config import SimpleGLADecoderConfig
from .registry import register_decoder
from .shortconv import ShortConvBlock

# Wrap layer calls in non-reentrant activation checkpointing when the
# GRAD_CKPT environment variable is set; otherwise call layers directly.
if "GRAD_CKPT" in os.environ:

    def maybe_grad_ckpt(f):
        def grad_ckpt_f(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(
                f, *args, **kwargs, use_reentrant=False
            )

        return grad_ckpt_f

else:

    def maybe_grad_ckpt(f):
        return f


class SimpleGLABlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int,
        expand_k: float,
        expand_v: float,
        use_short_conv: bool,
        ffn_expansion_factor: int,
    ):
        super().__init__()
        # expand_k, expand_v and use_short_conv are accepted for config
        # compatibility but are not forwarded to SimpleGatedLinearAttention here.
        self.tmix = SimpleGatedLinearAttention(
            hidden_size=dim,
            num_heads=num_heads,
            layer_idx=layer_idx,
        )
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
    ):
        # freqs and text_freqs are accepted for interface parity with other
        # decoder blocks; this block does not use them.
        x = (
            self.tmix(
                self.norm1(x),
                past_key_values=cache,
                use_cache=cache is not None,
            )[0]
            + x
        )
        x = self.cmix(self.norm2(x)) + x
        return x


class DecoderBlockWithOptionalCrossAttention(nn.Module):
    def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
        super().__init__()
        self.decoder_block = decoder_block
        self.crossatt = crossatt

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor | None = None,
        freqs: torch.Tensor | None = None,
        text_freqs: torch.Tensor | None = None,
        cache: Cache | None = None,
        selfatt_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # selfatt_mask is accepted for interface compatibility but unused here.
        x = self.decoder_block(
            x,
            freqs=freqs,
            cache=cache,
        )
        # A list of cross-attention masks carries one mask per layer.
        if isinstance(crossatt_mask, list):
            crossatt_mask = crossatt_mask[self.decoder_block.tmix.layer_idx]
        if self.crossatt is not None:
            x = x + self.crossatt(
                x,
                k=encoder_output,
                text_freqs=text_freqs,
                mask=crossatt_mask,
                cache=cache,
            )
        return x


@register_decoder("simple_gla")
class SimpleGLADecoder(nn.Module):
    config = SimpleGLADecoderConfig

    def __init__(self, cfg: SimpleGLADecoderConfig):
        super().__init__()
        assert cfg.dim % cfg.num_heads == 0, "num_heads must divide dim"
        assert cfg.blind_crossatt + (cfg.listen_read_crossatt is not None) < 2, (
            "at most one of blind_crossatt and listen_read_crossatt may be enabled"
        )
        self.head_dim = cfg.dim // cfg.num_heads
        self.num_heads = cfg.num_heads
        # Optional embedding of text positions; stays None unless set
        # externally (consumed by forward_first_n_layers).
        self.text_freqs_embd = None

        def simple_gla_block(i):
            conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
            if i in conv_layers:
                return ShortConvBlock(
                    dim=cfg.dim,
                    kernel_size=4,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                    layer_idx=i,
                    use_fast_conv1d=True,
                )
            else:
                return SimpleGLABlock(
                    dim=cfg.dim,
                    num_heads=cfg.num_heads,
                    layer_idx=i,
                    expand_k=cfg.expand_k,
                    expand_v=cfg.expand_v,
                    use_short_conv=cfg.use_short_conv,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                )

        def crossatt_block(i):
            if i in cfg.crossatt_layer_idx:
                return CrossAttention(
                    dim=cfg.dim,
                    num_heads=cfg.crossatt_num_heads,
                    dropout=cfg.crossatt_dropout,
                    layer_idx=i,
                )
            else:
                return None

        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlockWithOptionalCrossAttention(
                    simple_gla_block(i),
                    crossatt_block(i),
                )
                for i in range(cfg.num_layers)
            ]
        )

    def forward(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | list[torch.Tensor] | None = None,
        text_ids: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        # text_ids is accepted for interface compatibility and is unused here.
        x = decoder_input
        text_freqs = None
        for layer in self.decoder_layers:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x

    def init_cache(self, max_seq_len, device):
        # max_seq_len and device are accepted for interface compatibility;
        # FLACache is constructed without them.
        return FLACache(num_states=len(self.decoder_layers) + 1)

    def init_initial_state(self, batch_size=1, scale=1e-2, device="cpu"):
        # One learnable initial recurrent state of shape
        # (batch, num_heads, head_dim, head_dim) per decoder layer.
        return tuple(
            nn.Parameter(
                torch.randn(
                    batch_size,
                    self.num_heads,
                    self.head_dim,
                    self.head_dim,
                    device=device,
                )
                * scale
            )
            for _ in range(len(self.decoder_layers))
        )

    def init_initial_state_lora(
        self, lora: int = 1, batch_size: int = 1, scale: float = 1e-2, device: str = "cpu"
    ):
        # Low-rank factorization of the initial state: per layer, a
        # (head_dim x lora) factor and a (lora x head_dim) factor.
        return tuple(
            (
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        self.head_dim,
                        lora,
                        device=device,
                    )
                    * scale
                ),
                nn.Parameter(
                    torch.randn(
                        batch_size,
                        self.num_heads,
                        lora,
                        self.head_dim,
                        device=device,
                    )
                    * scale
                ),
            )
            for _ in range(len(self.decoder_layers))
        )

    def _get_query(self, audio_inputs: torch.Tensor, layer_idx: int):
        assert self.decoder_layers[layer_idx].crossatt is not None
        x = audio_inputs
        # Run the leading layers without encoder context, then apply this
        # layer's cross-attention query projection to the result.
        for _, layer in zip(range(layer_idx - 1), self.decoder_layers):
            x = layer(x, None)
        return self.decoder_layers[layer_idx].crossatt._query(x)

    def forward_first_n_layers(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        n_first_layers: int,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        # Run only the first n_first_layers decoder layers.
        x = decoder_input
        if self.text_freqs_embd is not None:
            text_freqs = torch.arange(encoder_output.shape[1], device=x.device)[None, :]
            text_freqs = self.text_freqs_embd(text_freqs)
        else:
            text_freqs = None
        for layer in self.decoder_layers[:n_first_layers]:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x

    def prefill(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        # Prefill is a full forward pass that populates the provided cache.
        return self(
            encoder_output, decoder_input, cache=cache, crossatt_mask=crossatt_mask
        )

    def decode_one(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        cache: Cache,
        text_freqs: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
    ):
        # Single-step decode: run every layer on the latest decoder input,
        # reading from and updating the cache.
        x = decoder_input
        for layer in self.decoder_layers:
            x = layer(
                x,
                encoder_output,
                text_freqs=text_freqs,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x
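

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the decoder itself).
# The SimpleGLADecoderConfig field values below are assumptions inferred from
# how this module reads the config; the actual config class in .config may
# expect different fields or defaults, so adapt before relying on this.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = SimpleGLADecoderConfig(
        dim=256,
        num_heads=4,
        num_layers=2,
        expand_k=1.0,
        expand_v=1.0,
        use_short_conv=False,
        ffn_expansion_factor=4,
        conv_layers=None,        # no ShortConvBlock layers
        crossatt_layer_idx=[1],  # cross-attention only on layer 1
        crossatt_num_heads=4,
        crossatt_dropout=0.0,
        blind_crossatt=False,
        listen_read_crossatt=None,
    )
    decoder = SimpleGLADecoder(cfg)

    # Dummy activations with shape (batch, time, dim).
    encoder_output = torch.randn(1, 12, cfg.dim)
    decoder_input = torch.randn(1, 24, cfg.dim)

    out = decoder(encoder_output, decoder_input)
    print(out.shape)  # expected: torch.Size([1, 24, 256])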