speech2phone-ctc / handler.py
import math
import onnxruntime
import numpy as np
import base64
import whisper
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import List, Any, Dict
from transformers import Wav2Vec2CTCTokenizer, PreTrainedModel, PretrainedConfig
import pycantonese
def parse_jyutping(jyutping: str) -> str:
"""Helper to parse Jyutping string using pycantonese."""
# Move the tone number to the end if it's not already there
if jyutping and not jyutping[-1].isdigit():
match = re.search(r"([1-6])", jyutping)
if match:
tone = match.group(1)
jyutping = jyutping.replace(tone, "") + tone
try:
        # pycantonese returns a list of Jyutping tuples; fall back to the raw string on failure
parsed_jyutping = pycantonese.parse_jyutping(jyutping)[0]
onset = parsed_jyutping.onset if parsed_jyutping.onset else ""
nucleus = parsed_jyutping.nucleus if parsed_jyutping.nucleus else ""
coda = parsed_jyutping.coda if parsed_jyutping.coda else ""
tone_val = str(parsed_jyutping.tone) if parsed_jyutping.tone else ""
# Construct the phoneme string, e.g., onset + nucleus + coda + tone
# This depends on the exact format your CTC model expects
return "".join([onset, nucleus, coda, tone_val]) # Simplified example
except Exception as e:
print(f"Failed to parse Jyutping '{jyutping}': {e}. Returning original.")
return jyutping
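# Illustration of parse_jyutping (derived from the logic above; the syllables are
# arbitrary examples, not test fixtures shipped with the repo):
#   parse_jyutping("gwok3") -> "gwok3"  (onset + nucleus + coda + tone re-joined)
#   parse_jyutping("jy6ut") -> "jyut6"  (a stray tone digit is first moved to the end)
# Anything pycantonese cannot parse is returned unchanged after logging the failure.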
class CTCTransformerConfig(PretrainedConfig):
def __init__(
self,
vocab_size=100, # number of unique speech tokens
num_labels=50, # number of phoneme IDs (+1 for blank)
eos_token_id=2,
bos_token_id=1,
pad_token_id=0,
blank_id=0, # blank token id for CTC decoding
hidden_size=384,
num_hidden_layers=50,
num_attention_heads=4,
intermediate_size=2048,
dropout=0.1,
max_position_embeddings=1024,
ctc_loss_reduction="mean",
ctc_zero_infinity=True,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.num_labels = num_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.dropout = dropout
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.blank_id = blank_id
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
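# Example instantiation (sketch; the sizes below are placeholders for illustration,
# the deployed values come from the checkpoint's config.json via from_pretrained):
#   config = CTCTransformerConfig(vocab_size=4096, num_labels=96, num_hidden_layers=6)
#   model = CTCTransformerModel(config)
# vocab_size counts the discrete speech tokens produced by the ONNX tokenizer;
# num_labels counts the phoneme inventory plus the CTC blank at index blank_id.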
class SinusoidalPositionEncoder(torch.nn.Module):
"""Sinusoidal positional embeddings for sequences"""
def __init__(self, d_model=384, dropout_rate=0.1):
super().__init__()
self.d_model = d_model
self.dropout = nn.Dropout(p=dropout_rate)
def encode(
self,
positions: torch.Tensor = None,
depth: int = None,
dtype: torch.dtype = torch.float32,
):
if depth is None:
depth = self.d_model
batch_size = positions.size(0)
positions = positions.type(dtype)
device = positions.device
        # Geometric timescale spacing (assumes an even model dimension)
depth_float = float(depth)
log_timescale_increment = torch.log(
torch.tensor([10000.0], dtype=dtype, device=device)
) / (depth_float / 2.0 - 1.0)
# Create position encodings
inv_timescales = torch.exp(
torch.arange(depth_float // 2, device=device, dtype=dtype)
* (-log_timescale_increment)
)
# Create correct shapes for broadcasting
pos_seq = positions.view(-1, 1) # [batch_size*seq_len, 1]
inv_timescales = inv_timescales.view(1, -1) # [1, depth//2]
scaled_time = pos_seq * inv_timescales # [batch_size*seq_len, depth//2]
# Apply sin and cos
sin_encodings = torch.sin(scaled_time)
cos_encodings = torch.cos(scaled_time)
        # Interleave: sine terms at even feature indices, cosines at odd
pos_encodings = torch.zeros(
positions.shape[0], positions.shape[1], depth, device=device, dtype=dtype
)
even_indices = torch.arange(0, depth, 2, device=device)
odd_indices = torch.arange(1, depth, 2, device=device)
pos_encodings[:, :, even_indices] = sin_encodings.view(
batch_size, -1, depth // 2
)
pos_encodings[:, :, odd_indices] = cos_encodings.view(
batch_size, -1, depth // 2
)
return pos_encodings
def forward(self, x):
batch_size, timesteps, input_dim = x.size()
# Create position indices [1, 2, ..., timesteps]
positions = (
torch.arange(1, timesteps + 1, device=x.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
position_encoding = self.encode(positions, input_dim, x.dtype)
# Apply dropout to the sum
return self.dropout(x + position_encoding)
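# Quick shape check (sketch): the encoder adds a [B, T, D] sinusoidal table to the
# input and returns the same shape. Even feature indices carry the sine terms and
# odd indices the matching cosines, with timescales spaced geometrically up to 10000.
#   pe = SinusoidalPositionEncoder(d_model=384)
#   x = torch.randn(2, 100, 384)
#   assert pe(x).shape == (2, 100, 384)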
class CTCTransformerModel(PreTrainedModel):
config_class = CTCTransformerConfig
def __init__(self, config):
super().__init__(config)
self.embed = nn.Embedding(
config.vocab_size + 1,
config.hidden_size,
padding_idx=config.vocab_size,
)
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
dropout=self.config.dropout,
activation="gelu",
batch_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config.num_hidden_layers
)
self.pos_embed = SinusoidalPositionEncoder(
d_model=config.hidden_size, dropout_rate=config.dropout
)
self.norm = nn.LayerNorm(config.hidden_size)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(
self,
input_ids,
attention_mask=None,
labels=None,
):
# Embed the input tokens
x = self.embed(input_ids)
x = self.norm(x)
# Add positional embeddings
x = self.pos_embed(x)
# Create mask for transformer
if attention_mask is not None:
# PyTorch transformer expects mask where True indicates positions to be MASKED (padding)
# Transformers attention_mask uses:
# - 1 for tokens that are NOT MASKED (should be attended to)
# - 0 for tokens that ARE MASKED (padding)
# So, we need to invert the attention_mask to match PyTorch Transformer's expectation
src_key_padding_mask = attention_mask == 0
else:
src_key_padding_mask = None
# Pass through encoder with proper masking
x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
x = self.norm(x)
# Project to output labels
logits = self.classifier(x) # [B, T, num_labels]
loss = None
        if labels is not None:
            # Fall back to full-length inputs when no attention mask is provided
            if attention_mask is None:
                attention_mask = torch.ones_like(input_ids)
            input_lengths = attention_mask.sum(-1)
            # Padded label positions are filled with -100 and are excluded from the targets
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16
log_probs = nn.functional.log_softmax(
logits, dim=-1, dtype=torch.float32
).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
                    blank=self.config.blank_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
return {"loss": loss, "logits": logits}
@torch.inference_mode()
    def predict(self, input_ids: torch.Tensor) -> List[int]:
        """Greedy CTC decode: collapse repeated ids and drop blanks from the per-frame argmax."""
blank_id = self.config.blank_id
# Create attention mask with 1s (not masked) for all positions
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(
input_ids.device
)
with torch.no_grad():
x = self.embed(input_ids)
x = self.pos_embed(x) # Add positional embeddings
# Using the same masking convention as forward method
encoded = self.encoder(x, src_key_padding_mask=(attention_mask == 0))
logits = self.classifier(encoded) # [1, T, V]
log_probs = F.log_softmax(logits, dim=-1) # [1, T, V]
pred_ids = torch.argmax(log_probs, dim=-1).squeeze(0).tolist()
# Greedy decode with collapse
pred_phoneme_ids = []
prev = None
for idx in pred_ids:
if idx != blank_id and idx != prev:
pred_phoneme_ids.append(idx)
prev = idx
return pred_phoneme_ids
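# Worked example of the greedy collapse in predict() above, with blank_id = 0:
#   per-frame argmax ids: [0, 0, 5, 5, 0, 7, 7, 7, 0, 5]
#   drop blanks, merge repeats -> [5, 7, 5]
# The surviving ids are mapped back to phoneme strings by the Wav2Vec2CTCTokenizer
# in EndpointHandler.__call__.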
def load_speech_tokenizer(speech_tokenizer_path: str):
"""Load speech tokenizer ONNX model."""
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
session = onnxruntime.InferenceSession(
speech_tokenizer_path,
sess_options=option,
providers=["CPUExecutionProvider"],
)
return session
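# Usage sketch (the path is a placeholder; EndpointHandler.__init__ resolves the
# real file name inside the model directory):
#   session = load_speech_tokenizer("speech_tokenizer_v2.onnx")
#   print([inp.name for inp in session.get_inputs()])  # mel features + feature length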
def extract_speech_token(audio, speech_tokenizer_session):
"""
Extract speech tokens from audio using speech tokenizer.
Args:
audio: audio signal (torch.Tensor or numpy.ndarray), shape (T,) at 16kHz
speech_tokenizer_session: ONNX speech tokenizer session
Returns:
speech_token: tensor of shape (1, num_tokens)
speech_token_len: tensor of shape (1,) with token sequence length
"""
# Ensure audio is on CPU for processing
if isinstance(audio, torch.Tensor):
audio = audio.cpu().numpy()
elif isinstance(audio, np.ndarray):
pass
else:
raise ValueError("Audio must be torch.Tensor or numpy.ndarray")
# Convert to torch tensor for mel-spectrogram
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)
# Extract mel-spectrogram (whisper format)
feat = whisper.log_mel_spectrogram(audio_tensor, n_mels=128)
# Run speech tokenizer
    feat_np = feat.detach().cpu().numpy()
    ort_inputs = {
        speech_tokenizer_session.get_inputs()[0].name: feat_np,
        speech_tokenizer_session.get_inputs()[1].name: np.array(
            [feat.shape[2]], dtype=np.int32
        ),
    }
    speech_token = speech_tokenizer_session.run(None, ort_inputs)[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32)
speech_token_len = torch.tensor([len(speech_token[0])], dtype=torch.int32)
return speech_token, speech_token_len
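# Usage sketch: a one-second silent clip at 16 kHz yields a short token sequence
# (the exact length depends on the tokenizer's frame rate):
#   dummy = np.zeros(16000, dtype=np.float32)
#   tokens, token_len = extract_speech_token(dummy, session)
#   assert tokens.shape[0] == 1 and token_len.item() == tokens.shape[1]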
class EndpointHandler:
def __init__(self, model_dir: str, **kwargs: Any):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.speech_tokenizer_session = load_speech_tokenizer(
f"{model_dir}/speech_tokenizer_v2.onnx"
)
self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_dir)
self.model = (
CTCTransformerModel.from_pretrained(
model_dir,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
.eval()
.to(device)
)
def preprocess(self, inputs):
waveform, original_sampling_rate = torchaudio.load(inputs)
if original_sampling_rate != 16000:
resampler = torchaudio.transforms.Resample(
orig_freq=original_sampling_rate, new_freq=16000
)
audio_array = resampler(waveform).numpy().flatten()
else:
audio_array = waveform.numpy().flatten()
return audio_array
    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
# get inputs, assuming a base64 encoded wav file
inputs = data.pop("inputs", data)
# decode base64 file and save to temp file
audio = inputs["audio"]
audio_bytes = base64.b64decode(audio)
temp_wav_path = "/tmp/temp.wav"
with open(temp_wav_path, "wb") as f:
f.write(audio_bytes)
audio_array = self.preprocess(temp_wav_path)
# Extract speech tokens
speech_token, speech_token_len = extract_speech_token(
audio_array, self.speech_tokenizer_session
)
with torch.no_grad():
speech_token = speech_token.to(next(self.model.parameters()).device)
outputs = self.model.predict(speech_token)
transcription = self.tokenizer.decode(outputs, skip_special_tokens=True)
print(transcription)
transcription = " ".join(
[parse_jyutping(jyt) for jyt in transcription.split(" ")]
)
return {"transcription": transcription}