from dataclasses import dataclass, field
from typing import Literal


@dataclass
class SimpleGLADecoderConfig:
    name: str = "simple_gla"
    dim: int = 512
    use_short_conv: bool = True
    expand_k: float = 0.5
    expand_v: float = 1.0
    num_heads: int = 4
    num_layers: int = 12
    ffn_expansion_factor: int = 4
    conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
    blind_crossatt: bool = True
    listen_read_crossatt: dict[int, list[Literal["listen", "read"]]] | None = None
    crossatt_num_heads: int = 1
    crossatt_dropout: float = 0.1
    crossatt_layer_idx: list[int] = field(default_factory=lambda: [5])


@dataclass
class TransformerEncoderConfig:
    name: str = "sa_transformer"
    dim: int = 512
    num_heads: int = 8
    num_layers: int = 6
    ffn_expansion_factor: int = 4


@dataclass
class ConvFormerEncoderConfig:
    name: str = "convformer_encoder"
    dim: int = 512
    num_heads: int = 8
    num_transformer_layers: int = 6
    num_conv_layers: int = 3
    ffn_expansion_factor: int = 4


@dataclass
class TransformerDecoderConfig:
    name: str = "sa_transformer"
    dim: int = 512
    num_heads: int = 8
    num_layers: int = 12
    conv_layers: list[int] | None = field(default_factory=lambda: [0, 2, 4, 6, 8, 10])
    ffn_expansion_factor: int = 4
    crossatt_dropout: float = 0.1
    crossatt_num_heads: int = 1
    crossatt_layer_idx: list[int] = field(default_factory=lambda: [5, 6])


# Union aliases used to select the encoder/decoder config variant.
DecoderConfig = TransformerDecoderConfig | SimpleGLADecoderConfig
EncoderConfig = TransformerEncoderConfig | ConvFormerEncoderConfig


@dataclass
class TTSConfig:
    dim: int = 512
    text_vocab_size: int = 256 + 3
    audio_vocab_size: int = 4096 + 3
    audio_embed_size: int = 16
    audio_input_type: Literal["continuous", "discrete"] = "continuous"
    diffusion_head_num_layers: int = 3
    encoder_cfg: EncoderConfig = field(default_factory=TransformerEncoderConfig)
    decoder_cfg: DecoderConfig = field(default_factory=TransformerDecoderConfig)
    stop_prediction_head: bool = True
    multi_stop_prediction_head: bool = False
    multi_stop_num_tokens: int = 8
    multi_stop_padding_idx: int | None = None
    multi_stop_tie_embd: bool = True
    stop_token_embd: bool = False
    text_stop_token_embd: bool = False
    continuous_diffusion: bool = True
    num_sink_tokens: int = 0
    disabled_crossatt_head_idx: list[tuple[int, int]] | None = None
    patchvae_path: str | None = None
    text_tokenizer_path: str | None = None
    eos: bool = False
    bos: bool = False


@dataclass
class PlayHeadConfig:
    selected_cross_attention_heads: list[tuple[int, int]]
    dim: int = 256
    num_layers: int = 6
    num_frame_lag: int = 2
    num_sink_tokens: int = 4
    cycle_len: int = 8
    logits_head: bool = True
    target_lag: int = 0
    avg_pool_stride: int = 3
    circular_head: bool = False


@dataclass
class QueryVCConfig:
    dim: int = 512
    semantic_dim: int = 512
    num_layers: int = 12
    lag: int = 4
    audio_embed_size: int = 8
    diffusion_head_num_layers: int = 3
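

# --- Illustrative usage (a minimal sketch; the __main__ guard, the chosen field
# overrides, and the example head indices below are assumptions for illustration,
# not part of the original module) ---
if __name__ == "__main__":
    # Compose a TTSConfig that swaps in the ConvFormer encoder and the
    # SimpleGLA decoder in place of the Transformer defaults.
    cfg = TTSConfig(
        encoder_cfg=ConvFormerEncoderConfig(),
        decoder_cfg=SimpleGLADecoderConfig(num_layers=8),
        audio_input_type="discrete",
    )
    print(cfg.decoder_cfg.name, cfg.decoder_cfg.num_layers)

    # PlayHeadConfig has one required field; (layer, head) pairs here are
    # placeholder values, not recommended settings.
    play_head = PlayHeadConfig(selected_cross_attention_heads=[(5, 0), (6, 0)])
    print(play_head.num_frame_lag, play_head.cycle_len)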