import math
import onnxruntime
import numpy as np
import base64
import whisper
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import List, Any, Dict
from transformers import Wav2Vec2CTCTokenizer, PreTrainedModel, PretrainedConfig
import pycantonese


def parse_jyutping(jyutping: str) -> str:
    """Helper to parse a Jyutping string using pycantonese."""
    # Move the tone number to the end if it's not already there
    if jyutping and not jyutping[-1].isdigit():
        match = re.search(r"([1-6])", jyutping)
        if match:
            tone = match.group(1)
            jyutping = jyutping.replace(tone, "") + tone
    try:
        # Ensure pycantonese is installed and working
        parsed_jyutping = pycantonese.parse_jyutping(jyutping)[0]
        onset = parsed_jyutping.onset if parsed_jyutping.onset else ""
        nucleus = parsed_jyutping.nucleus if parsed_jyutping.nucleus else ""
        coda = parsed_jyutping.coda if parsed_jyutping.coda else ""
        tone_val = str(parsed_jyutping.tone) if parsed_jyutping.tone else ""
        # Construct the phoneme string as onset + nucleus + coda + tone.
        # This depends on the exact format your CTC model expects.
        return "".join([onset, nucleus, coda, tone_val])  # Simplified example
    except Exception as e:
        print(f"Failed to parse Jyutping '{jyutping}': {e}. Returning original.")
        return jyutping


class CTCTransformerConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=100,  # number of unique speech tokens
        num_labels=50,  # number of phoneme IDs (+1 for blank)
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
        blank_id=0,  # blank token id for CTC decoding
        hidden_size=384,
        num_hidden_layers=50,
        num_attention_heads=4,
        intermediate_size=2048,
        dropout=0.1,
        max_position_embeddings=1024,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.dropout = dropout
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.pad_token_id = pad_token_id
        self.blank_id = blank_id
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity


class SinusoidalPositionEncoder(torch.nn.Module):
    """Sinusoidal positional embeddings for sequences."""

    def __init__(self, d_model=384, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout_rate)

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        if depth is None:
            depth = self.d_model
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device

        # Assumes an even depth
        depth_float = float(depth)
        log_timescale_increment = torch.log(
            torch.tensor([10000.0], dtype=dtype, device=device)
        ) / (depth_float / 2.0 - 1.0)

        # Inverse timescales for the sinusoids
        inv_timescales = torch.exp(
            torch.arange(depth_float // 2, device=device, dtype=dtype)
            * (-log_timescale_increment)
        )

        # Create correct shapes for broadcasting
        pos_seq = positions.view(-1, 1)  # [batch_size*seq_len, 1]
        inv_timescales = inv_timescales.view(1, -1)  # [1, depth//2]
        scaled_time = pos_seq * inv_timescales  # [batch_size*seq_len, depth//2]

        # Apply sin and cos
        sin_encodings = torch.sin(scaled_time)
        cos_encodings = torch.cos(scaled_time)

        # Interleave sin (even indices) and cos (odd indices)
        pos_encodings = torch.zeros(
            positions.shape[0], positions.shape[1], depth, device=device, dtype=dtype
        )
        even_indices = torch.arange(0, depth, 2, device=device)
        odd_indices = torch.arange(1, depth, 2, device=device)
        pos_encodings[:, :, even_indices] = sin_encodings.view(
            batch_size, -1, depth // 2
        )
        pos_encodings[:, :, odd_indices] = cos_encodings.view(
            batch_size, -1, depth // 2
        )

        return pos_encodings

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()

        # Create position indices [1, 2, ..., timesteps]
        positions = (
            torch.arange(1, timesteps + 1, device=x.device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )

        position_encoding = self.encode(positions, input_dim, x.dtype)

        # Apply dropout to the sum
        return self.dropout(x + position_encoding)


class CTCTransformerModel(PreTrainedModel):
    config_class = CTCTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(
            config.vocab_size + 1,
            config.hidden_size,
            padding_idx=config.vocab_size,
        )
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=self.config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers
        )
        self.pos_embed = SinusoidalPositionEncoder(
            d_model=config.hidden_size, dropout_rate=config.dropout
        )
        self.norm = nn.LayerNorm(config.hidden_size)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        # Embed the input speech tokens
        x = self.embed(input_ids)
        x = self.norm(x)

        # Add positional embeddings
        x = self.pos_embed(x)

        # Create the padding mask for the transformer
        if attention_mask is not None:
            # PyTorch's transformer expects a mask where True marks positions to be
            # MASKED (padding), while the transformers-style attention_mask uses
            #   - 1 for tokens that are NOT masked (should be attended to)
            #   - 0 for tokens that ARE masked (padding),
            # so the attention_mask has to be inverted.
            src_key_padding_mask = attention_mask == 0
        else:
            src_key_padding_mask = None

        # Pass through the encoder with proper masking
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        x = self.norm(x)

        # Project to output labels
        logits = self.classifier(x)  # [B, T, num_labels]

        loss = None
        if labels is not None:
            input_lengths = attention_mask.sum(-1)

            # Assuming that padded label positions are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(
                logits, dim=-1, dtype=torch.float32
            ).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=0,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return {"loss": loss, "logits": logits}

    @torch.inference_mode()
    def predict(self, input_ids: torch.Tensor) -> List[int]:
        blank_id = self.config.blank_id

        # Create an attention mask with 1s (not masked) for all positions
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(
            input_ids.device
        )

        with torch.no_grad():
            x = self.embed(input_ids)
            x = self.pos_embed(x)  # Add positional embeddings

            # Using the same masking convention as the forward method
            encoded = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
            logits = self.classifier(encoded)  # [1, T, V]
            log_probs = F.log_softmax(logits, dim=-1)  # [1, T, V]
            pred_ids = torch.argmax(log_probs, dim=-1).squeeze(0).tolist()

        # Greedy (best-path) decode: collapse repeats, then drop blanks
        pred_phoneme_ids = []
        prev = None
        for idx in pred_ids:
            if idx != blank_id and idx != prev:
                pred_phoneme_ids.append(idx)
            prev = idx

        return pred_phoneme_ids


def load_speech_tokenizer(speech_tokenizer_path: str):
    """Load the speech tokenizer ONNX model."""
    option = onnxruntime.SessionOptions()
    option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
    option.intra_op_num_threads = 1
    session = onnxruntime.InferenceSession(
        speech_tokenizer_path,
        sess_options=option,
        providers=["CPUExecutionProvider"],
    )
    return session


def extract_speech_token(audio, speech_tokenizer_session):
    """
    Extract speech tokens from audio using the speech tokenizer.

    Args:
        audio: audio signal (torch.Tensor or numpy.ndarray), shape (T,) at 16 kHz
        speech_tokenizer_session: ONNX speech tokenizer session

    Returns:
        speech_token: tensor of shape (1, num_tokens)
        speech_token_len: tensor of shape (1,) with the token sequence length
    """
    # Ensure audio is a numpy array on CPU
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().numpy()
    elif isinstance(audio, np.ndarray):
        pass
    else:
        raise ValueError("Audio must be torch.Tensor or numpy.ndarray")

    # Convert to a torch tensor for the mel-spectrogram
    audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)

    # Extract the log-mel spectrogram (Whisper format)
    feat = whisper.log_mel_spectrogram(audio_tensor, n_mels=128)

    # Run the speech tokenizer
    speech_token = (
        speech_tokenizer_session.run(
            None,
            {
                speech_tokenizer_session.get_inputs()[0].name: feat.detach()
                .cpu()
                .numpy(),
                speech_tokenizer_session.get_inputs()[1].name: np.array(
                    [feat.shape[2]], dtype=np.int32
                ),
            },
        )[0]
        .flatten()
        .tolist()
    )
    speech_token = torch.tensor([speech_token], dtype=torch.int32)
    speech_token_len = torch.tensor([len(speech_token[0])], dtype=torch.int32)

    return speech_token, speech_token_len


class EndpointHandler:
    def __init__(self, model_dir: str, **kwargs: Any):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.speech_tokenizer_session = load_speech_tokenizer(
            f"{model_dir}/speech_tokenizer_v2.onnx"
        )
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_dir)
        self.model = (
            CTCTransformerModel.from_pretrained(
                model_dir,
                torch_dtype=torch.bfloat16,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            .eval()
            .to(device)
        )

    def preprocess(self, inputs):
        # Load the wav file and resample to 16 kHz if needed
        waveform, original_sampling_rate = torchaudio.load(inputs)
        if original_sampling_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sampling_rate, new_freq=16000
            )
            audio_array = resampler(waveform).numpy().flatten()
        else:
            audio_array = waveform.numpy().flatten()
        return audio_array

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Get the inputs, assuming a base64-encoded wav file
        inputs = data.pop("inputs", data)

        # Decode the base64 payload and save it to a temporary file
        audio = inputs["audio"]
        audio_bytes = base64.b64decode(audio)
        temp_wav_path = "/tmp/temp.wav"
        with open(temp_wav_path, "wb") as f:
            f.write(audio_bytes)

        audio_array = self.preprocess(temp_wav_path)

        # Extract speech tokens
        speech_token, speech_token_len = extract_speech_token(
            audio_array, self.speech_tokenizer_session
        )

        with torch.no_grad():
            speech_token = speech_token.to(next(self.model.parameters()).device)
            outputs = self.model.predict(speech_token)

        transcription = self.tokenizer.decode(outputs, skip_special_tokens=True)
        print(transcription)
        transcription = " ".join(
            [parse_jyutping(jyt) for jyt in transcription.split(" ")]
        )

        return {"transcription": transcription}
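

# Optional local smoke test (not part of the inference-endpoint contract).
# This is a minimal sketch under assumed inputs: a local directory "model/"
# containing speech_tokenizer_v2.onnx plus the tokenizer and model weights,
# and a 16 kHz mono "sample.wav" next to this script. Both names are
# placeholders, not files shipped with the repository.
if __name__ == "__main__":
    import json

    # Build the same JSON-style payload the handler expects:
    # a base64-encoded wav file under the "audio" key of "inputs".
    with open("sample.wav", "rb") as f:
        payload = {"inputs": {"audio": base64.b64encode(f.read()).decode("utf-8")}}

    handler = EndpointHandler(model_dir="model")
    result = handler(payload)
    print(json.dumps(result, ensure_ascii=False))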