# pardi-speech / tts/model/shortconv.py
# Mehdi Lakbar — "Initial demo of Lina-speech (pardi-speech)" (56cfa73)
import torch
from fla.modules import ShortConvolution
from torch import nn
from tts.layers.ffn import SwiGLU
from tts.model.cache_utils import FLACache
class ShortConvBlock(nn.Module):
    """Pre-norm block: short convolution (token mixing) then SwiGLU (channel mixing).

    ``forward`` computes ``cmix(norm2(tmix(norm1(x))))``. No residual
    connections are added inside this block.
    NOTE(review): confirm that the caller adds residuals if they are intended.

    When a ``cache`` is supplied, the short-convolution state stored under
    ``layer_idx`` is fed into the convolution, and the state it returns is
    written back to the cache. This is presumably for incremental decoding.

    Args:
        dim: Feature size used by the convolution, the FFN and both LayerNorms.
        kernel_size: Kernel width passed to ``ShortConvolution``.
        ffn_expansion_factor: Passed to ``SwiGLU`` (presumably the hidden-size
            multiplier; confirm against ``tts.layers.ffn``).
        layer_idx: Index of this block's entry inside the shared cache.
        use_fast_conv1d: Forwarded unchanged to ``ShortConvolution``.
    """

    def __init__(
        self,
        dim: int,
        kernel_size: int,
        ffn_expansion_factor: int,
        layer_idx: int,
        use_fast_conv1d: bool = True,
    ):
        super().__init__()
        # Position of this block's state inside a shared FLACache.
        self.layer_idx = layer_idx
        # Submodules are registered in the order tmix, cmix, norm1, norm2, so
        # parameter ordering and any RNG-dependent initialization are unchanged.
        self.tmix = ShortConvolution(
            dim, kernel_size, use_fast_conv1d=use_fast_conv1d
        )
        self.cmix = SwiGLU(dim, ffn_expansion_factor)
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, cache: FLACache | None = None, **kwargs):
        """Apply norm1 -> short conv -> norm2 -> SwiGLU, syncing conv state with ``cache``.

        Args:
            x: Input tensor; presumably ``(batch, seq_len, dim)``. Dimension 1
                of the convolution output is used as the cache offset.
            cache: Optional cache. If it already holds an entry for
                ``layer_idx``, that entry's ``"short_conv_state"`` seeds the
                convolution. After the convolution, the new state is stored via
                ``cache.update``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The output of ``cmix`` applied to the normalized convolution output.
        """
        use_cache = cache is not None

        # Earlier calls only populate entries up to len(cache), so a layer
        # without a stored state starts from no state (None).
        prev_conv_state = (
            cache[self.layer_idx]["short_conv_state"]
            if use_cache and len(cache) > self.layer_idx
            else None
        )

        mixed, conv_state = self.tmix(
            self.norm1(x),
            cache=prev_conv_state,
            output_final_state=use_cache,
        )

        if use_cache:
            cache.update(
                short_conv_state=conv_state,
                layer_idx=self.layer_idx,
                offset=mixed.shape[1],
            )

        return self.cmix(self.norm2(mixed))