# tiny-audio-residual / asr_pipeline.py
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import torch
import transformers
from truecase import get_true_case
from .asr_modeling import ASRModel


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
"""ASR Pipeline for audio-to-text transcription."""
model: ASRModel
def __init__(self, model: ASRModel, **kwargs):
feature_extractor = kwargs.pop("feature_extractor", None)
tokenizer = kwargs.pop("tokenizer", model.tokenizer)
# Get feature extractor from model's processor if not provided
if feature_extractor is None:
processor = model.get_processor()
feature_extractor = processor.feature_extractor
super().__init__(
model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
)
# Initialize text normalizer
if hasattr(tokenizer, "normalize"):
self.text_normalizer = tokenizer
else:
from transformers import WhisperTokenizer
self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    def __call__(self, inputs, **kwargs):
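        """Transcribe audio inputs, or dispatch to a text-only task.

        Generation-related kwargs (beam search, sampling, task selection, ...)
        are split out from preprocessing kwargs before dispatching, so both
        kinds can be mixed in a single call.
        """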
generate_kwargs = {}
generate_keys = [
"max_new_tokens",
"num_beams",
"do_sample",
"length_penalty",
"repetition_penalty",
"no_repeat_ngram_size",
"early_stopping",
"num_beam_groups",
"diversity_penalty",
"top_k",
"temperature",
"top_p",
"user_prompt",
"task",
"text_input",
]
for key in generate_keys:
if key in kwargs:
generate_kwargs[key] = kwargs.pop(key)
# Handle text-only mode
task = generate_kwargs.get("task")
if task == "text" or generate_kwargs.get("text_input"):
return self._process_text_only(generate_kwargs)
# Handle list inputs
if isinstance(inputs, list):
return [self.__call__(inp, **kwargs, **generate_kwargs) for inp in inputs]
model_inputs = self.preprocess(inputs, **kwargs)
if isinstance(model_inputs, Iterator):
return self._process_chunks(list(model_inputs), generate_kwargs)
model_outputs = self._forward(model_inputs, **generate_kwargs)
return self.postprocess(model_outputs)

    def _process_chunks(self, chunks: list, generate_kwargs: dict) -> dict[str, str]:
"""Process chunked audio and merge results."""
all_tokens: list[int] = []
for chunk in chunks:
output = self._forward(chunk, **generate_kwargs)
tokens = output.get("tokens")
if tokens is None:
tokens = output.get("generated_ids")
            if tokens is not None:
                if torch.is_tensor(tokens):
                    # Take the first sequence of a batched tensor before
                    # flattening to a plain list of ids.
                    tokens = tokens.cpu()
                    if tokens.dim() > 1:
                        tokens = tokens[0]
                    all_tokens.extend(tokens.tolist())
                else:
                    # Already a plain id list; checking .shape here would fail.
                    all_tokens.extend(tokens)
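        # Decode the concatenated ids in a single pass so tokens that straddle
        # a chunk boundary are not broken up by per-chunk decoding.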
text = self.tokenizer.decode(all_tokens, skip_special_tokens=True).strip()
text = self.text_normalizer.normalize(text)
text = get_true_case(text)
return {"text": text}

    def preprocess(self, inputs, **preprocess_params):
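        """Normalize the accepted input formats before delegating to the base class.

        Handles raw bytes (dict with "bytes"), datasets-style dicts ("array" +
        "sampling_rate"), objects exposing .array/.sampling_rate, numpy-like
        arrays, and torch tensors.
        """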
if isinstance(inputs, list):
raise ValueError("Lists should not reach preprocess")
preprocess_params.setdefault("chunk_length_s", 0)
# Normalize input formats
if isinstance(inputs, dict):
if "bytes" in inputs:
inputs = self._decode_audio_bytes(inputs["bytes"])
elif "array" in inputs:
inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
elif torch.is_tensor(inputs):
inputs = {
"raw": inputs.cpu().numpy(),
"sampling_rate": self.model.config.audio_sample_rate,
}
return super().preprocess(inputs, **preprocess_params)

    def _decode_audio_bytes(self, wav_bytes: bytes) -> dict[str, Any]:
"""Decode audio bytes to array format."""
import tempfile
from torchcodec.decoders import AudioDecoder
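        # The decoder is pointed at a file path here, so spill the bytes to a
        # temporary .wav and remove it once decoding is done.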
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
f.write(wav_bytes)
temp_path = f.name
try:
decoder = AudioDecoder(temp_path)
audio_result = decoder.get_all_samples()
return {
"raw": audio_result.data.squeeze().numpy(),
"sampling_rate": audio_result.sample_rate,
}
finally:
Path(temp_path).unlink()

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
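        """Run generation for one (possibly chunked) set of audio inputs."""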
task: str | None = generate_kwargs.pop("task", None)
# Task-specific defaults
task_params: dict[str, dict[str, Any]] = {
"transcribe": {"do_sample": False},
"emotion": {"do_sample": True, "temperature": 0.7},
"describe": {"do_sample": True, "temperature": 0.7},
"continue": {"do_sample": True, "temperature": 1.0},
}
if task is not None and task in task_params:
for key, value in task_params[task].items():
generate_kwargs.setdefault(key, value)
# Extract audio from model_inputs
audio_inputs, is_whisper = self._extract_audio(model_inputs)
is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True
        # Generation defaults: stop at the ChatML-style <|im_end|> end-of-turn
        # token and cap new tokens from the model config.
generate_kwargs.setdefault(
"eos_token_id", self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
)
generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
# Generate
if is_whisper:
generated_ids = self.model.generate(
input_features=audio_inputs,
task=task,
**generate_kwargs,
)
else:
generated_ids = self.model.generate(
input_values=audio_inputs,
task=task,
**generate_kwargs,
)
return {"tokens": generated_ids, "is_last": is_last}

    def _extract_audio(self, model_inputs) -> tuple[torch.Tensor, bool]:
"""Extract audio tensor from various input formats."""
if isinstance(model_inputs, torch.Tensor):
return model_inputs.to(self.model.device), False
if isinstance(model_inputs, (list, tuple)) and model_inputs:
model_inputs = (
model_inputs[0]
if isinstance(model_inputs[0], dict)
else {"input_values": model_inputs[0]}
)
if isinstance(model_inputs, dict):
model_inputs.pop("stride", None)
if "input_features" in model_inputs:
return model_inputs["input_features"].to(self.model.device), True
if "input_values" in model_inputs:
return model_inputs["input_values"].to(self.model.device), False
raise ValueError(f"Could not extract audio from {type(model_inputs)}")

    def _process_text_only(self, generate_kwargs: dict) -> dict[str, str]:
"""Process text-only input without audio."""
text_input = generate_kwargs.pop("text_input", None)
if text_input is None:
raise ValueError("text_input required for text task")
generate_kwargs.pop("task", None)
generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
return {"text": text}

    def postprocess(
self, model_outputs: dict[str, Any], return_timestamps=None, return_language=None
) -> dict[str, str]:
if isinstance(model_outputs, list):
for output in model_outputs:
for key, value in output.items():
if torch.is_tensor(value):
output[key] = value.cpu()
return super().postprocess(model_outputs)
model_outputs.pop("is_last", None)
tokens = model_outputs.get("tokens") or model_outputs.get("generated_ids")
if tokens is None:
raise ValueError(f"Expected 'tokens' or 'generated_ids', got: {model_outputs.keys()}")
if torch.is_tensor(tokens) and tokens.device.type != "cpu":
tokens = tokens.cpu()
if len(tokens.shape) > 1:
tokens = tokens[0]
text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
text = self.text_normalizer.normalize(text)
text = get_true_case(text)
return {"text": text}