Audio-to-Audio
Transformers
Safetensors
dashengtokenizer
feature-extraction
audio-classification
signal-processing
custom_code
Instructions to use mispeech/dashengtokenizer with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use mispeech/dashengtokenizer with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("mispeech/dashengtokenizer", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| from .configuration_dasheng_tokenizer import DashengTokenizerConfig | |
| from .modeling_dasheng_encoder import DashengEncoder | |
| from .vocos import VocosModel | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| import torchaudio | |
| from transformers import PreTrainedModel | |
class VocosMelSpec(torch.nn.Module):
    """Log-magnitude mel-spectrogram frontend in the style used by Vocos.

    Wraps ``torchaudio.transforms.MelSpectrogram`` (magnitude spectrogram,
    ``power=1``) and returns the natural log of the result, floored at 1e-7.

    Args:
        sample_rate: Sample rate handed to the mel transform.
        n_fft: FFT size handed to the mel transform.
        hop_length: Hop between successive frames, in samples.
        n_mels: Number of mel bins.
        padding: ``"center"`` lets the transform centre each frame itself;
            ``"same"`` disables centring and reflect-pads the waveform in
            ``forward`` instead.

    Raises:
        ValueError: If ``padding`` is neither ``"center"`` nor ``"same"``.
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")

        self.padding = padding
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels

        # Only the "center" mode relies on the transform's own framing.
        use_center = self.padding == "center"

        # NOTE(review): the transform is built under an explicit CPU device
        # context — presumably so its buffers are materialised on CPU even if
        # the surrounding model is constructed under another default device
        # (e.g. "meta"); confirm against the loading path.
        with torch.device("cpu"):
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=self.n_mels,
                center=use_center,
                power=1,
            )

    def forward(self, audio, **kwargs):
        """Return the log-mel spectrogram of ``audio``.

        Args:
            audio: Waveform tensor; padding (if any) is applied along its
                last dimension.
            **kwargs: Accepted and ignored.

        Returns:
            ``log(max(mel, 1e-7))`` as produced by the wrapped mel transform.
        """
        if self.padding == "same":
            # Split (win_length - hop_length) samples of reflect padding
            # evenly between both ends of the signal.
            overlap = self.mel_spec.win_length - self.mel_spec.hop_length
            side = overlap // 2
            audio = torch.nn.functional.pad(audio, (side, side), mode="reflect")

        magnitudes = self.mel_spec(audio)
        # Floor before the log so silent bins do not produce -inf.
        return torch.clip(magnitudes, min=1e-7).log()
class DashengTokenizerEncoder(torch.nn.Module):
    """Encoder that adds semantic features to acoustic mel-patch embeddings.

    Two branches process the same waveform:

    * a semantic branch, ``DashengEncoder`` with its ``outputlayer`` replaced
      by an identity, evaluated under ``torch.no_grad()``;
    * an acoustic branch: log-mel spectrogram -> non-overlapping
      ``(n_mels_patch, 4)`` convolutional patches -> ``LayerNorm``.

    The per-token outputs of both branches are summed.

    Args:
        embed_dim: Embedding width passed to ``DashengEncoder``.
        depth: Depth passed to ``DashengEncoder``.
        num_heads: Number of heads passed to ``DashengEncoder``.
        n_mels_patch: Number of mel bins of the acoustic front end; also the
            height of each patch, so a patch spans the full mel axis.
        hop_length: Hop length (in samples) of the acoustic mel front end.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        embed_dim: int = 1280,
        depth: int = 32,
        num_heads: int = 16,
        n_mels_patch: int = 128,
        hop_length: int = 160,
        **kwargs,
    ):
        super().__init__()
        # Semantic branch; its output head is bypassed so raw features come out.
        self.model = DashengEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads)
        self.embed_dim = int(self.model.embed_dim)
        self.model.outputlayer = torch.nn.Identity()

        # Acoustic branch: mel front end, patch embedding, normalisation.
        self.front_end = VocosMelSpec(hop_length=hop_length, n_mels=n_mels_patch)
        # Kernel == stride: each patch covers all mel bins and 4 frames,
        # without overlap.
        self.patch_embed = torch.nn.Conv2d(
            1, self.model.embed_dim, (n_mels_patch, 4), (n_mels_patch, 4)
        )
        self.norm = torch.nn.LayerNorm(self.model.embed_dim)

        # Stored for reference. Note these are read from the semantic model's
        # own front end, not from ``self.front_end`` above.
        self.n_fft = self.model.front_end.n_fft
        self.hop_size = self.model.front_end.hop_size

    def forward(
        self,
        input: torch.Tensor,
        # ``Optional[...]`` rather than ``X | None``: the latter is evaluated
        # at definition time and fails on Python < 3.10; it also matches the
        # annotations used elsewhere in this file.
        input_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of waveforms into combined embeddings.

        Args:
            input: Audio tensor of shape (batch_size, num_samples).
            input_attn_mask: Optional attention mask, forwarded unchanged to
                the semantic encoder.

        Returns:
            Combined embeddings of shape (batch_size, num_tokens, embed_dim),
            where ``num_tokens`` is the token count of the acoustic branch.
        """
        # Semantic branch is evaluated without gradient tracking.
        with torch.no_grad():
            semantic_emb = self.model(input, input_attn_mask)

        # Acoustic branch: add a channel axis so the mel map is a 1-channel image.
        mel = self.front_end(input).unsqueeze(1)
        mel_emb = self.patch_embed(mel)
        acoustic_emb = rearrange(mel_emb, "b c f t -> b (f t) c")
        acoustic_emb = self.norm(acoustic_emb)

        # Crop the semantic sequence to the acoustic token count before the
        # sum; assumes the semantic branch yields at least as many tokens.
        semantic_emb = semantic_emb[:, : acoustic_emb.shape[1], :]
        return semantic_emb + acoustic_emb
class DashengTokenizerPreTrainedModel(PreTrainedModel):
    """Shared ``PreTrainedModel`` base for the DashengTokenizer model classes."""

    # Declares that models derived from this base support gradient checkpointing.
    supports_gradient_checkpointing = True
    # Configuration class associated with models derived from this base.
    config_class = DashengTokenizerConfig
class DashengTokenizerModel(DashengTokenizerPreTrainedModel):
    """HuggingFace-compatible DashEng tokenizer model (encoder + decoder).

    Bundles a ``DashengTokenizerEncoder``, an optional transposed-convolution
    token upsampler, and a ``VocosModel`` decoder for end-to-end audio
    processing (audio -> embeddings -> audio).
    """

    def __init__(self, config: DashengTokenizerConfig):
        """Build encoder, optional upsampler and decoder from ``config``.

        Args:
            config: Model configuration. Read here: ``embed_dim``, ``depth``,
                ``num_heads``, ``n_mels_patch``, ``hop_length``,
                ``upsample_tokens``, ``decoder_embed_dim``,
                ``decoder_intermediate_size``, ``istft_hop``, ``istft_n_fft``
                and ``decoder_depth``.
        """
        super().__init__(config)
        self.config = config

        self.encoder = DashengTokenizerEncoder(
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            n_mels_patch=config.n_mels_patch,
            hop_length=config.hop_length,
        )
        self.embed_dim = self.encoder.embed_dim

        # Optional upsampler: with kernel == stride, every token is expanded
        # into ``upsample_tokens`` consecutive frames.
        self.upsampler = None
        if config.upsample_tokens > 1:
            self.upsampler = torch.nn.ConvTranspose1d(
                self.embed_dim,
                self.embed_dim,
                kernel_size=config.upsample_tokens,
                stride=config.upsample_tokens,
            )

        # Decoder
        self.decoder = VocosModel(
            input_channels=self.embed_dim,
            hidden_dim=config.decoder_embed_dim,
            intermediate_dim=config.decoder_intermediate_size,
            vocos_istft_hop=config.istft_hop,
            vocos_n_fft=config.istft_n_fft,
            num_layers=config.decoder_depth,
        )
        self.post_init()

    def encode(
        self,
        audio: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode audio into embeddings.

        Args:
            audio: Audio tensor of shape (batch_size, num_samples).
            attention_mask: Optional attention mask, forwarded to the encoder.

        Returns:
            The encoder output for ``audio``.
        """
        return self.encoder(audio, attention_mask)

    def decode(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Decode embeddings back to audio.

        Args:
            embeddings: Tensor whose last two dimensions are
                (num_tokens, embed_dim), as produced by ``encode``.

        Returns:
            The output of the Vocos decoder.
        """
        # Both the upsampler and the decoder are fed channels-first input,
        # so swap the last two axes once up front.
        features = embeddings.transpose(-2, -1)
        if self.upsampler is not None:
            features = self.upsampler(features)
        return self.decoder(features)

    def forward(
        self,
        audio: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], dict]:
        """Encode ``audio`` and reconstruct it through the decoder.

        Args:
            audio: Audio tensor of shape (batch_size, num_samples).
            attention_mask: Optional attention mask, forwarded to the encoder.
            return_dict: Whether to return a dict; when ``None`` the value of
                ``config.use_return_dict`` is used.

        Returns:
            If ``return_dict`` resolves to false, the 1-tuple
            ``(audio_reconstructed,)``. Otherwise a dict with keys
            ``"audio"`` (the reconstructed audio) and ``"embeddings"`` (the
            encoder output).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embeddings = self.encode(audio, attention_mask)
        audio_reconstructed = self.decode(embeddings)

        if not return_dict:
            return (audio_reconstructed,)
        return {
            "audio": audio_reconstructed,
            "embeddings": embeddings,
        }