import os

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers.cache_utils import DynamicCache

from tts.layers.attention import CrossAttention, SelfAttention, precompute_freqs_cis
from tts.layers.conv import ConvNeXtBlock
from tts.layers.ffn import SwiGLU
from tts.model.cache_utils import FLACache, TransformerDecoderCache
from tts.model.config import (
    ConvFormerEncoderConfig,
    TransformerDecoderConfig,
    TransformerEncoderConfig,
)
from tts.model.registry import register_decoder, register_encoder
from tts.model.shortconv import ShortConvBlock

# Optionally wrap layers with activation (gradient) checkpointing when the
# GRAD_CKPT environment variable is set; otherwise return the layer unchanged.
if "GRAD_CKPT" in os.environ:

    def maybe_grad_ckpt(f):
        def grad_ckpt_f(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(
                f, *args, **kwargs, use_reentrant=False
            )

        return grad_ckpt_f

else:

    def maybe_grad_ckpt(f):
        return f


class TransformerBlock(nn.Module):
    # Pre-norm transformer block: self-attention followed by a SwiGLU
    # feed-forward network, each with a residual connection.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        layer_idx: int,
        ffn_expansion_factor: int,
        is_causal: bool,
    ):
        super().__init__()
        self.tmix = SelfAttention(dim, num_heads, layer_idx, is_causal=is_causal)
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        freqs: torch.Tensor | None = None,
        cache: DynamicCache | None = None,
        mask: torch.Tensor | None = None,
    ):
        x = self.tmix(self.norm1(x), freqs=freqs, cache=cache, mask=mask) + x
        x = self.cmix(self.norm2(x)) + x
        return x


class DecoderBlockWithOptionalCrossAttention(nn.Module):
    # Wraps a decoder block (self-attention or short-conv) and optionally adds a
    # residual cross-attention over the encoder output.
    def __init__(self, decoder_block: nn.Module, crossatt: nn.Module | None = None):
        super().__init__()
        self.decoder_block = decoder_block
        self.crossatt = crossatt

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor | None = None,
        freqs: torch.Tensor | None = None,
        cache: FLACache | None = None,
        selfatt_mask: torch.Tensor | None = None,
        crossatt_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = self.decoder_block(
            x,
            freqs=freqs,
            cache=cache,
        )
        if self.crossatt is not None:
            x = x + self.crossatt(
                x,
                k=encoder_output,
                mask=crossatt_mask,
                cache=cache,
            )
        return x


@register_decoder("sa_transformer")
class TransformerDecoder(nn.Module):
    config = TransformerDecoderConfig

    def __init__(self, cfg: TransformerDecoderConfig):
        super().__init__()
        assert cfg.dim % cfg.num_heads == 0
        self.head_dim = cfg.dim // cfg.num_heads

        def transformer_block(i):
            # Layers listed in cfg.conv_layers use a short convolution block
            # instead of causal self-attention.
            conv_layers = [] if cfg.conv_layers is None else cfg.conv_layers
            if i in conv_layers:
                return ShortConvBlock(
                    dim=cfg.dim,
                    kernel_size=4,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                    layer_idx=i,
                    use_fast_conv1d=True,
                )
            else:
                return TransformerBlock(
                    dim=cfg.dim,
                    num_heads=cfg.num_heads,
                    layer_idx=i,
                    ffn_expansion_factor=cfg.ffn_expansion_factor,
                    is_causal=True,
                )

        def crossatt_block(i):
            return CrossAttention(
                dim=cfg.dim,
                num_heads=cfg.crossatt_num_heads,
                dropout=cfg.crossatt_dropout,
                layer_idx=i,
            )

        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlockWithOptionalCrossAttention(
                    transformer_block(i),
                    crossatt_block(i) if i in cfg.crossatt_layer_idx else None,
                )
                for i in range(cfg.num_layers)
            ]
        )

    def forward(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        crossatt_mask: torch.Tensor | None = None,
        cache: FLACache | None = None,
    ):
        x = decoder_input
        positions = torch.arange(x.shape[1], device=x.device)
        freqs = precompute_freqs_cis(positions, self.head_dim)
        for layer in self.decoder_layers:
            x = maybe_grad_ckpt(layer)(
                x,
                encoder_output,
                freqs=freqs,
                crossatt_mask=crossatt_mask,
                cache=cache,
            )
        return x

    def init_cache(self, max_seq_len, device):
        return FLACache(head_dim=self.head_dim, max_seq_len=max_seq_len, device=device)

    def prefill(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        cache: FLACache | None = None,
    ):
        return self(encoder_output, decoder_input, cache=cache)

    def decode_one(
        self,
        encoder_output: torch.Tensor,
        decoder_input: torch.Tensor,
        cache: FLACache,
        crossatt_mask: torch.Tensor | None = None,
    ):
        # Single-token decoding step: look up the rotary frequencies for the
        # current position from the cache.
        x = decoder_input
        pos = cache._seen_tokens
        freq = cache.freqs[:, [pos]]
        for layer in self.decoder_layers:
            x = layer(
                x,
                encoder_output,
                freqs=freq,
                cache=cache,
                crossatt_mask=crossatt_mask,
            )
        return x


@register_encoder("sa_transformer")
class TransformerEncoder(nn.Module):
    config = TransformerEncoderConfig

    def __init__(self, cfg: TransformerEncoderConfig):
        super().__init__()
        assert cfg.dim % cfg.num_heads == 0
        self.head_dim = cfg.dim // cfg.num_heads
        self.encoder_layers = nn.ModuleList(
            [
                TransformerBlock(
                    cfg.dim,
                    cfg.num_heads,
                    i,
                    cfg.ffn_expansion_factor,
                    is_causal=False,
                )
                for i in range(cfg.num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
    ):
        positions = torch.arange(x.shape[1], device=x.device)
        freqs = precompute_freqs_cis(positions, self.head_dim)
        if mask is not None:
            mask = rearrange(mask, "b n m -> b 1 n m")
            # OR in the identity so every position can always attend to itself.
            mask = torch.logical_or(
                mask,
                rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"),
            )
        for layer in self.encoder_layers:
            x = layer(x, freqs=freqs, mask=mask)
        return x


@register_encoder("convformer_encoder")
class ConvFormerEncoder(nn.Module):
    config = ConvFormerEncoderConfig

    def __init__(self, cfg: ConvFormerEncoderConfig):
        super().__init__()
        assert cfg.dim % cfg.num_heads == 0
        self.head_dim = cfg.dim // cfg.num_heads
        self.conv_layers = nn.ModuleList(
            [ConvNeXtBlock(cfg.dim) for _ in range(cfg.num_conv_layers)]
        )
        self.encoder_layers = nn.ModuleList(
            [
                TransformerBlock(
                    cfg.dim,
                    cfg.num_heads,
                    i,
                    cfg.ffn_expansion_factor,
                    is_causal=False,
                )
                for i in range(cfg.num_transformer_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        text_rel_pos: torch.Tensor | None = None,
    ):
        if text_rel_pos is None:
            text_rel_pos = torch.arange(x.shape[1], device=x.device)
            freqs = precompute_freqs_cis(text_rel_pos, self.head_dim)
        else:
            freqs = precompute_freqs_cis(text_rel_pos, self.head_dim).unsqueeze(1)
        # Run the ConvNeXt stack in channels-first layout, then switch back.
        x = x.transpose(1, 2)
        for layer in self.conv_layers:
            x = layer(x)
        x = x.transpose(1, 2)
        if mask is not None:
            mask = rearrange(mask, "b n m -> b 1 n m")
            # OR in the identity so every position can always attend to itself.
            mask = torch.logical_or(
                mask,
                rearrange(torch.eye(mask.shape[-1], device=x.device), "n m -> 1 1 n m"),
            )
        for layer in self.encoder_layers:
            x = layer(x, freqs=freqs, mask=mask)
        return x
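

# Illustrative smoke test, not part of the library: a minimal sketch of how the
# registered encoder could be exercised. It assumes TransformerEncoderConfig is
# keyword-constructible with the fields read above (dim, num_heads, num_layers,
# ffn_expansion_factor); the values below are hypothetical.
if __name__ == "__main__":
    cfg = TransformerEncoderConfig(
        dim=256,
        num_heads=4,
        num_layers=2,
        ffn_expansion_factor=4,
    )
    encoder = TransformerEncoder(cfg)
    tokens = torch.randn(2, 16, 256)  # (batch, seq_len, dim)
    out = encoder(tokens)
    print(out.shape)  # expected: torch.Size([2, 16, 256])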