# Musci-ASR-2.4B / processing_Musci.py
# Uploaded by Musci-research ("upload Musci-ASR-2.4B"), commit 6cb6a8a (verified).
from dataclasses import dataclass
from typing import Optional, Union, List
import torch
import numpy as np
from transformers import BatchEncoding
from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor
import importlib.util
import sys
@dataclass()
class MelConfig:
    """Settings for the log-mel front end.

    ``MusciProcessor`` forwards the integer fields to ``WhisperFeatureExtractor``
    (sampling rate, mel bin count, FFT size, hop length) and casts the extracted
    features to ``mel_dtype``.
    """

    # Sampling rate (Hz) the feature extractor is configured with.
    mel_sr: int = 16000
    # Number of mel bins, i.e. the feature dimension.
    mel_dim: int = 80
    # FFT window length in samples.
    mel_n_fft: int = 640
    # Hop between consecutive frames in samples (160 at 16 kHz -> 100 frames/s).
    mel_hop_length: int = 160
    # Dtype the mel features are converted to before being handed to the model.
    mel_dtype: torch.dtype = torch.bfloat16
def load_chat_template(template_path: str, package_path: Optional[str] = None) -> List:
    """Import a chat-template module from a file path and return its ``chat_template``.

    Args:
        template_path: Path to a Python source file defining a module-level
            ``chat_template`` attribute.
        package_path: Optional directory prepended to ``sys.path`` (when not
            already present) so the template module can import its siblings.
            The entry is left on ``sys.path`` after the call.

    Returns:
        The ``chat_template`` attribute of the executed module.

    Raises:
        ImportError: if no module spec/loader can be built for ``template_path``
            (e.g. the file does not have a Python source suffix).
        Any exception raised while executing the module (e.g. ``FileNotFoundError``
        for a missing file) propagates unchanged; the half-initialised module is
        removed from ``sys.modules`` first.
    """
    module_name = "chat_template_module"
    if package_path and package_path not in sys.path:
        sys.path.insert(0, package_path)
    spec = importlib.util.spec_from_file_location(module_name, template_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader.
        raise ImportError(f"Cannot load chat template from {template_path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file
    # directly" recipe. The fixed name means each call replaces the previously
    # loaded template module.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a broken, half-executed module registered.
        sys.modules.pop(module_name, None)
        raise
    return module.chat_template
class MusciProcessor:
    """Audio processor for Musci ASR: mel-spectrogram + chat-template-driven token layout.

    Calling the processor on one waveform returns a ``BatchEncoding`` holding:

    * ``input_ids`` -- ``(1, L)`` long tensor; audio positions carry
      ``audio_placeholder_id`` (plus digit tokens when time markers are on).
    * ``attention_mask`` -- all ones, same shape as ``input_ids``.
    * ``audio_data`` -- log-mel features (frames on the last axis), cast to
      ``config.mel_dtype``.
    * ``audio_data_seqlens`` -- ``(1,)`` tensor with the mel frame count.
    * ``audio_input_mask`` -- ``(1, L)`` bool tensor, ``True`` at the audio
      placeholder positions.
    """

    def __init__(
        self,
        tokenizer,
        config: Optional["MelConfig"] = None,
        template_path: Optional[str] = None,
        enable_time_marker: bool = False,
    ):
        """Build the processor.

        Args:
            tokenizer: Text tokenizer; only used by ``decode``/``batch_decode``.
            config: Mel front-end settings; defaults to ``MelConfig()``.
            template_path: Optional path to a chat-template module loaded via
                ``load_chat_template``. Without one, ``__call__`` uses the
                hard-coded legacy layout.
            enable_time_marker: If True, digit tokens for the elapsed second are
                interleaved with the audio placeholders.
        """
        self.tokenizer = tokenizer
        self.config = config or MelConfig()
        # Whisper log-mel frontend — matches the front-end the model was trained with.
        self.feature_extractor = WhisperFeatureExtractor(
            feature_size=int(self.config.mel_dim),
            sampling_rate=int(self.config.mel_sr),
            hop_length=int(self.config.mel_hop_length),
            n_fft=int(self.config.mel_n_fft),
        )
        # Special token ids (Qwen3 tokenizer).
        self.start_token_id = 151644
        self.end_token_id = 151645
        self.audio_start_token_id = 151669
        self.audio_end_token_id = 151670
        # Filler id for audio positions; presumably the model replaces these
        # embeddings with encoder outputs selected via ``audio_input_mask``.
        self.audio_placeholder_id = 0
        self.chat_template = None if template_path is None else load_chat_template(template_path)
        self.enable_time_marker = enable_time_marker
        # Digit tokens 0-9 in the Qwen3 tokenizer, used for time markers.
        self._digit_token_ids = {str(d): 15 + d for d in range(10)}
        # NOTE(review): _get_feat_extract_output_lengths yields 13 tokens per full
        # 100 mel frames (1 s at hop 160 / 16 kHz), not 12.5. Markers therefore
        # lag the true audio position by ~1 token every 2 s. Left as-is because
        # it must match how the model was trained — confirm.
        self.audio_tokens_per_second = 12.5
        self.time_marker_every_seconds = 2
        self.time_marker_every_audio_tokens = int(
            self.audio_tokens_per_second * self.time_marker_every_seconds
        )

    def load_template(self, template_path: str):
        """Load (or replace) the chat template from ``template_path``; returns ``self``."""
        self.chat_template = load_chat_template(template_path)
        print(f"Loaded chat template from {template_path}")
        return self

    def _get_feat_extract_output_lengths(self, input_lengths):
        """Map raw mel-frame count to number of audio tokens after the encoder downsample.

        Each full chunk of 100 mel frames contributes 13 tokens. The leftover
        ``input_lengths % 100`` frames go through three ``(x - 1) // 2 + 1``
        reductions. Only ``%``, ``//``, ``+`` and ``*`` are used, so plain ints
        and integer tensors should both work.
        """
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = (
            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        )
        return output_lengths

    def _get_time_marker_token_ids(self, second: int) -> List[int]:
        """Spell ``second`` as digit token ids, one per decimal digit (10 -> "1", "0")."""
        return [self._digit_token_ids[c] for c in str(second)]

    def _build_audio_tokens_with_time_markers(self, audio_seq_len: int) -> List[int]:
        """Interleave time markers every `time_marker_every_seconds` seconds of audio tokens.

        Returns exactly ``audio_seq_len`` placeholder ids. After every
        ``time_marker_every_audio_tokens`` placeholders, the digit tokens of the
        elapsed second are inserted. A marker that lands exactly at the end is
        still emitted, so the list may end with a marker.
        """
        num_full_seconds = int(audio_seq_len / self.audio_tokens_per_second)
        tokens_list: List[int] = []
        audio_tokens_consumed = 0
        for second in range(
            self.time_marker_every_seconds, num_full_seconds + 1, self.time_marker_every_seconds
        ):
            # Number of placeholders that must precede this marker.
            marker_pos = (
                (second // self.time_marker_every_seconds) * self.time_marker_every_audio_tokens
            )
            segment_len = marker_pos - audio_tokens_consumed
            if segment_len > 0:
                tokens_list.extend([self.audio_placeholder_id] * segment_len)
                audio_tokens_consumed += segment_len
            tokens_list.extend(self._get_time_marker_token_ids(second))
        remaining = audio_seq_len - audio_tokens_consumed
        if remaining > 0:
            tokens_list.extend([self.audio_placeholder_id] * remaining)
        return tokens_list

    def _build_input_from_template(self, num_audio_tokens: int) -> tuple:
        """Walk the loaded chat_template and emit (input_ids, audio_input_mask) for inference.

        Segments are assumed to expose ``.type`` and, for constant text, a
        ``.text_ids`` object with ``.tolist()`` (e.g. a tensor). The template
        module defines them — verify there. Segment types other than those
        handled below are skipped.

        Raises:
            ValueError: if no chat template has been loaded.
        """
        if self.chat_template is None:
            raise ValueError("Chat template not loaded. Call load_template() first.")
        input_ids: List[int] = []
        audio_mask: List[bool] = []
        for segment in self.chat_template:
            seg_type = segment.type
            if seg_type == "constant_text_token":
                text_ids = segment.text_ids.tolist()
                input_ids.extend(text_ids)
                audio_mask.extend([False] * len(text_ids))
            elif seg_type in ("audio_contiguous", "audio_token"):
                if self.enable_time_marker:
                    audio_ids = self._build_audio_tokens_with_time_markers(num_audio_tokens)
                    input_ids.extend(audio_ids)
                    audio_mask.extend(
                        [tok == self.audio_placeholder_id for tok in audio_ids]
                    )
                else:
                    input_ids.extend([self.audio_placeholder_id] * num_audio_tokens)
                    audio_mask.extend([True] * num_audio_tokens)
            elif seg_type == "text_token":
                # Generation starts here at inference time.
                break
        return input_ids, audio_mask

    def _build_input_legacy(self, num_audio_tokens: int) -> tuple:
        """Hardcoded [start, audio_start, audio*, audio_end] layout, used when no template is loaded."""
        if self.enable_time_marker:
            audio_ids = self._build_audio_tokens_with_time_markers(num_audio_tokens)
            ids = (
                [self.start_token_id, self.audio_start_token_id]
                + audio_ids
                + [self.audio_end_token_id]
            )
            audio_mask = [tok == self.audio_placeholder_id for tok in audio_ids]
            mask = [False, False] + audio_mask + [False]
        else:
            ids = (
                [self.start_token_id, self.audio_start_token_id]
                + [self.audio_placeholder_id] * num_audio_tokens
                + [self.audio_end_token_id]
            )
            mask = [False, False] + [True] * num_audio_tokens + [False]
        return ids, mask

    def __call__(
        self,
        audio: Union[np.ndarray, torch.Tensor],
        return_tensors: str = "pt",
        **kwargs,
    ):
        """Convert one waveform into model inputs.

        Args:
            audio: Mono waveform, or a 2-D array whose first row (``audio[0]``)
                is used. No resampling happens here; presumably the audio must
                already be at ``config.mel_sr``.
            return_tensors: Forwarded to ``BatchEncoding`` as ``tensor_type``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            A ``BatchEncoding`` with ``input_ids``, ``attention_mask``,
            ``audio_data``, ``audio_data_seqlens`` and ``audio_input_mask``.

        Raises:
            ValueError: if ``audio`` is None, empty, or not 1-D/2-D.
        """
        if audio is None:
            raise ValueError("Audio input is required.")
        if isinstance(audio, torch.Tensor):
            waveform = audio.detach().to(dtype=torch.float32).cpu().numpy()
        else:
            waveform = np.asarray(audio, dtype=np.float32)
        # Reject empty input before indexing / STFT, which would fail obscurely.
        if waveform.size == 0:
            raise ValueError("Audio input is empty.")
        if waveform.ndim == 2:
            # NOTE(review): row 0 is taken, i.e. a (channels, samples) layout is
            # assumed; a (samples, channels) array would yield a single sample.
            waveform = waveform[0]
        if waveform.ndim != 1:
            raise ValueError(
                f"Expected a 1-D or 2-D waveform, got shape {waveform.shape}."
            )
        try:
            # Newer transformers: batched ``(waveform_batch, device)`` API that
            # returns (batch, n_mels, frames); take the single item.
            mel = self.feature_extractor._np_extract_fbank_features(
                waveform[None, ...], device="cpu"
            )[0]
        except TypeError:
            # Older transformers: ``_np_extract_fbank_features(waveform)`` takes a
            # single 1-D waveform (no ``device``) and returns (n_mels, frames),
            # so pass the 1-D array and do not index the result.
            mel = self.feature_extractor._np_extract_fbank_features(waveform)
        input_features = torch.from_numpy(mel).to(self.config.mel_dtype)
        if input_features.dim() == 3:
            input_features = input_features.squeeze(0)
        raw_mel_len = input_features.shape[-1]
        num_audio_tokens = self._get_feat_extract_output_lengths(raw_mel_len)
        if self.chat_template is not None:
            ids, mask = self._build_input_from_template(num_audio_tokens)
        else:
            ids, mask = self._build_input_legacy(num_audio_tokens)
        input_ids_tensor = torch.tensor([ids], dtype=torch.long)
        audio_mask_tensor = torch.tensor([mask], dtype=torch.bool)
        attention_mask_tensor = torch.ones_like(input_ids_tensor)
        seq_lens_tensor = torch.tensor([raw_mel_len], dtype=torch.long)
        data = {
            "input_ids": input_ids_tensor,
            "attention_mask": attention_mask_tensor,
            "audio_data": input_features,
            "audio_data_seqlens": seq_lens_tensor,
            "audio_input_mask": audio_mask_tensor,
        }
        return BatchEncoding(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Delegate to ``self.tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)