import torch
from fla.modules import ShortConvolution
from torch import nn

from tts.layers.ffn import SwiGLU
from tts.model.cache_utils import FLACache


class ShortConvBlock(nn.Module):
    """Block pairing a short causal convolution (token mixing) with a SwiGLU
    feed-forward network (channel mixing); each sub-layer is preceded by LayerNorm."""

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        ffn_expansion_factor: int,
        layer_idx: int,
        use_fast_conv1d: bool = True,
    ):
        super().__init__()
        # Token mixer: short convolution over the sequence dimension.
        self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d)
        # Channel mixer: SwiGLU feed-forward network.
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
        # Reuse this layer's convolution state from the cache, if one exists.
        last_state = None
        if cache is not None and len(cache) > self.layer_idx:
            last_state = cache[self.layer_idx]["short_conv_state"]
        x, short_conv_state = self.tmix(
            self.norm1(x), cache=last_state, output_final_state=cache is not None
        )
        if cache is not None:
            # Store the updated convolution state for incremental decoding.
            cache.update(
                short_conv_state=short_conv_state,
                layer_idx=self.layer_idx,
                offset=x.shape[1],
            )
        x = self.cmix(self.norm2(x))
        return x
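
# The snippet below is a minimal usage sketch, not part of the original module:
# the hidden size 512, kernel size 4, expansion factor 4, and two-layer stack are
# illustrative assumptions, and the FLACache path is skipped (cache=None) so no
# incremental decoding state is involved. use_fast_conv1d=False avoids depending
# on the optional causal-conv1d kernel.
if __name__ == "__main__":
    blocks = nn.ModuleList(
        [
            ShortConvBlock(
                dim=512,
                kernel_size=4,
                ffn_expansion_factor=4,
                layer_idx=i,
                use_fast_conv1d=False,
            )
            for i in range(2)
        ]
    )
    x = torch.randn(1, 16, 512)  # (batch, sequence length, dim)
    for block in blocks:
        x = block(x)  # full-sequence forward; shape is preserved
    print(x.shape)  # torch.Size([1, 16, 512])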