import torch
from torch import nn

from fla.modules import ShortConvolution

from tts.layers.ffn import SwiGLU
from tts.model.cache_utils import FLACache


class ShortConvBlock(nn.Module):
    """Pre-norm block pairing a short causal convolution (token mixing)
    with a SwiGLU feed-forward network (channel mixing)."""

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        ffn_expansion_factor: int,
        layer_idx: int,
        use_fast_conv1d: bool = True,
    ):
        super().__init__()
        # Token mixing: short depthwise convolution along the sequence dimension.
        self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d)
        # Channel mixing: gated feed-forward network.
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
        # Fetch this layer's convolution state from the cache, if present
        # (only the case during incremental decoding).
        last_state = None
        if cache is not None and len(cache) > self.layer_idx:
            last_state = cache[self.layer_idx]["short_conv_state"]
        x, short_conv_state = self.tmix(
            self.norm1(x),
            cache=last_state,
            output_final_state=cache is not None,
        )
        # Persist the updated convolution state for the next decoding step.
        if cache is not None:
            cache.update(
                short_conv_state=short_conv_state,
                layer_idx=self.layer_idx,
                offset=x.shape[1],
            )
        x = self.cmix(self.norm2(x))
        return x
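
if __name__ == "__main__":
    # Minimal smoke test, a sketch only: assumes `fla` and the `tts` package
    # are importable and that SwiGLU/ShortConvolution preserve the input shape.
    # The dim/kernel_size/ffn_expansion_factor values are illustrative, not
    # taken from the model's actual configuration.
    block = ShortConvBlock(dim=256, kernel_size=4, ffn_expansion_factor=4, layer_idx=0)
    x = torch.randn(2, 128, 256)  # (batch, seq_len, dim)
    y = block(x)  # no cache: plain parallel (training-style) forward pass
    print(y.shape)  # expected: torch.Size([2, 128, 256])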