File size: 1,272 Bytes
56cfa73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from fla.modules import ShortConvolution
from torch import nn

from tts.layers.ffn import SwiGLU
from tts.model.cache_utils import FLACache


class ShortConvBlock(nn.Module):
    """Pre-norm block: LayerNorm -> ShortConvolution, then LayerNorm -> SwiGLU.

    The block optionally reads and writes a per-layer short-convolution state
    through an ``FLACache``, keyed by ``layer_idx``.

    NOTE(review): neither sub-layer has a residual connection here. Each
    sub-layer's output replaces its input. Confirm the caller adds the
    residual path, or that its absence is intentional.

    Args:
        dim: Feature size. It is passed to ``ShortConvolution``, ``SwiGLU`` and
            both ``nn.LayerNorm`` layers.
        kernel_size: Kernel size forwarded to ``ShortConvolution``.
        ffn_expansion_factor: Expansion factor forwarded to ``SwiGLU``.
        layer_idx: This block's slot in the cache, used for both lookup and
            update.
        use_fast_conv1d: Forwarded unchanged to ``ShortConvolution``.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        ffn_expansion_factor: int,
        layer_idx: int,
        use_fast_conv1d: bool = True,
    ):
        super().__init__()
        # Registration order (tmix, cmix, norm1, norm2) is kept as-is. The
        # attribute names are state_dict keys, and construction order decides
        # how parameter initialisation consumes the RNG.
        self.tmix = ShortConvolution(dim, kernel_size, use_fast_conv1d=use_fast_conv1d)
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.layer_idx = layer_idx

    def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
        """Run the conv sub-layer and the SwiGLU sub-layer, updating ``cache`` if given.

        Args:
            x: Input tensor. It is presumably (batch, seq, dim), because the
                cache offset is taken from dimension 1 (TODO confirm).
            cache: Optional ``FLACache``. If it already holds an entry for
                ``layer_idx``, that entry's ``"short_conv_state"`` seeds the
                convolution. If a cache is given at all, the new final state
                is written back with the processed length as ``offset``.
            **kwargs: Accepted and ignored.

        Returns:
            The output of ``cmix`` (SwiGLU) applied to the normalised conv output.
        """
        use_cache = cache is not None

        # Only read a previous state if the cache already has this layer's slot.
        prev_state = (
            cache[self.layer_idx]["short_conv_state"]
            if use_cache and len(cache) > self.layer_idx
            else None
        )

        # Ask for the final conv state only when there is somewhere to store it.
        hidden, conv_state = self.tmix(
            self.norm1(x), cache=prev_state, output_final_state=use_cache
        )

        if use_cache:
            cache.update(
                short_conv_state=conv_state,
                layer_idx=self.layer_idx,
                offset=hidden.shape[1],
            )

        return self.cmix(self.norm2(hidden))