from collections.abc import Iterator
from pathlib import Path
from typing import Any

import torch
import transformers
from truecase import get_true_case

from .asr_modeling import ASRModel
|
|
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    model: ASRModel
|
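    # Illustrative usage (assumes ASRModel exposes a ``from_pretrained``
    # constructor, as Hugging Face model classes typically do):
    #
    #   model = ASRModel.from_pretrained("path/to/checkpoint")
    #   pipe = ASRPipeline(model=model)
    #   print(pipe("sample.wav", task="transcribe", num_beams=4)["text"])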
    def __init__(self, model: ASRModel, **kwargs):
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        # Fall back to the model's own processor when no feature extractor is
        # supplied explicitly.
        if feature_extractor is None:
            processor = model.get_processor()
            feature_extractor = processor.feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )

        # Prefer the tokenizer's own normalizer; otherwise borrow Whisper's
        # text normalizer for postprocessing.
        if hasattr(tokenizer, "normalize"):
            self.text_normalizer = tokenizer
        else:
            from transformers import WhisperTokenizer

            self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
|
    def __call__(self, inputs, **kwargs):
        # Split generation-time options out of the pipeline kwargs.
        generate_kwargs = {}
        generate_keys = [
            "max_new_tokens",
            "num_beams",
            "do_sample",
            "length_penalty",
            "repetition_penalty",
            "no_repeat_ngram_size",
            "early_stopping",
            "num_beam_groups",
            "diversity_penalty",
            "top_k",
            "temperature",
            "top_p",
            "user_prompt",
            "task",
            "text_input",
        ]
        for key in generate_keys:
            if key in kwargs:
                generate_kwargs[key] = kwargs.pop(key)

        # Text-only requests bypass audio preprocessing entirely.
        task = generate_kwargs.get("task")
        if task == "text" or generate_kwargs.get("text_input"):
            return self._process_text_only(generate_kwargs)

        if isinstance(inputs, list):
            return [self.__call__(inp, **kwargs, **generate_kwargs) for inp in inputs]

        model_inputs = self.preprocess(inputs, **kwargs)

        # The base pipeline yields an iterator when the audio was chunked;
        # merge the per-chunk outputs into a single transcription.
        if isinstance(model_inputs, Iterator):
            return self._process_chunks(list(model_inputs), generate_kwargs)

        model_outputs = self._forward(model_inputs, **generate_kwargs)
        return self.postprocess(model_outputs)
|
    def _process_chunks(self, chunks: list, generate_kwargs: dict) -> dict[str, str]:
        """Process chunked audio and merge results."""
        all_tokens: list[int] = []

        for chunk in chunks:
            output = self._forward(chunk, **generate_kwargs)
            tokens = output.get("tokens")
            if tokens is None:
                tokens = output.get("generated_ids")
            if tokens is not None:
                if torch.is_tensor(tokens):
                    tokens = tokens.cpu()
                    # Strip the batch dimension before concatenating.
                    if len(tokens.shape) > 1:
                        tokens = tokens[0]
                    tokens = tokens.tolist()
                all_tokens.extend(tokens)

        text = self.tokenizer.decode(all_tokens, skip_special_tokens=True).strip()
        text = self.text_normalizer.normalize(text)
        text = get_true_case(text)

        return {"text": text}
|
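    # Input formats normalized by preprocess (illustrative; `samples` stands
    # in for a 1-D float array):
    #
    #   pipe({"bytes": wav_bytes})
    #   pipe({"array": samples, "sampling_rate": 16_000})
    #   pipe(samples)  # ndarray/Tensor; sample rate read from model config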
    def preprocess(self, inputs, **preprocess_params):
        if isinstance(inputs, list):
            raise ValueError("Lists should not reach preprocess")

        preprocess_params.setdefault("chunk_length_s", 0)

        # Normalize the accepted input formats to the
        # {"raw": array, "sampling_rate": int} dict the base class expects.
        if isinstance(inputs, dict):
            if "bytes" in inputs:
                inputs = self._decode_audio_bytes(inputs["bytes"])
            elif "array" in inputs:
                inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
        elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
            inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
        elif torch.is_tensor(inputs):
            # Tensors must be checked before the generic __array__ branch:
            # they also implement __array__ but may live on an accelerator.
            inputs = {
                "raw": inputs.cpu().numpy(),
                "sampling_rate": self.model.config.audio_sample_rate,
            }
        elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
            inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}

        return super().preprocess(inputs, **preprocess_params)
|
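    # Byte decoding below round-trips through a temporary file so torchcodec
    # can open it by path. Recent torchcodec releases may also accept raw
    # bytes directly (e.g. AudioDecoder(wav_bytes)); verify against the
    # pinned torchcodec version before relying on that.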
    def _decode_audio_bytes(self, wav_bytes: bytes) -> dict[str, Any]:
        """Decode audio bytes to array format."""
        import tempfile

        from torchcodec.decoders import AudioDecoder

        # delete=False so the decoder can reopen the file by path; the
        # finally block below removes it.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(wav_bytes)
            temp_path = f.name

        try:
            decoder = AudioDecoder(temp_path)
            audio_result = decoder.get_all_samples()
            return {
                "raw": audio_result.data.squeeze().numpy(),
                "sampling_rate": audio_result.sample_rate,
            }
        finally:
            Path(temp_path).unlink()
|
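    # The per-task presets below mean, e.g., that pipe("a.wav", task="emotion")
    # samples at temperature 0.7 unless the caller overrides those kwargs.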
    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        task: str | None = generate_kwargs.pop("task", None)

        # Per-task generation presets; setdefault never overwrites kwargs the
        # caller passed explicitly.
        task_params: dict[str, dict[str, Any]] = {
            "transcribe": {"do_sample": False},
            "emotion": {"do_sample": True, "temperature": 0.7},
            "describe": {"do_sample": True, "temperature": 0.7},
            "continue": {"do_sample": True, "temperature": 1.0},
        }
        if task in task_params:
            for key, value in task_params[task].items():
                generate_kwargs.setdefault(key, value)

        audio_inputs, is_whisper = self._extract_audio(model_inputs)
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        generate_kwargs.setdefault(
            "eos_token_id", self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
        )
        generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)

        # Whisper-style encoders take `input_features`; raw-waveform models
        # take `input_values`.
        if is_whisper:
            generated_ids = self.model.generate(
                input_features=audio_inputs,
                task=task,
                **generate_kwargs,
            )
        else:
            generated_ids = self.model.generate(
                input_values=audio_inputs,
                task=task,
                **generate_kwargs,
            )

        return {"tokens": generated_ids, "is_last": is_last}
|
    def _extract_audio(self, model_inputs) -> tuple[torch.Tensor, bool]:
        """Extract the audio tensor from various input formats.

        Returns the tensor and a flag indicating whether it is a Whisper-style
        feature batch (True) or raw waveform values (False).
        """
        if isinstance(model_inputs, torch.Tensor):
            return model_inputs.to(self.model.device), False

        if isinstance(model_inputs, (list, tuple)) and model_inputs:
            model_inputs = (
                model_inputs[0]
                if isinstance(model_inputs[0], dict)
                else {"input_values": model_inputs[0]}
            )

        if isinstance(model_inputs, dict):
            model_inputs.pop("stride", None)
            if "input_features" in model_inputs:
                return model_inputs["input_features"].to(self.model.device), True
            if "input_values" in model_inputs:
                return model_inputs["input_values"].to(self.model.device), False

        raise ValueError(f"Could not extract audio from {type(model_inputs)}")
|
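    # Text-only usage sketch (the audio argument is ignored on this path, so
    # any placeholder works):
    #
    #   pipe(None, task="text", text_input="<prompt>")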
    def _process_text_only(self, generate_kwargs: dict) -> dict[str, str]:
        """Process text-only input without audio."""
        text_input = generate_kwargs.pop("text_input", None)
        if text_input is None:
            raise ValueError("text_input required for text task")

        generate_kwargs.pop("task", None)
        generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        return {"text": text}
|
    def postprocess(
        self, model_outputs: dict[str, Any], return_timestamps=None, return_language=None
    ) -> dict[str, str]:
        # Chunked outputs arrive as a list; move tensors to the CPU and let
        # the base class merge them.
        if isinstance(model_outputs, list):
            for output in model_outputs:
                for key, value in output.items():
                    if torch.is_tensor(value):
                        output[key] = value.cpu()
            return super().postprocess(model_outputs)

        model_outputs.pop("is_last", None)
        # Avoid `or`-chaining the lookups: boolean evaluation of a
        # multi-element tensor raises a RuntimeError.
        tokens = model_outputs.get("tokens")
        if tokens is None:
            tokens = model_outputs.get("generated_ids")

        if tokens is None:
            raise ValueError(f"Expected 'tokens' or 'generated_ids', got: {model_outputs.keys()}")

        if torch.is_tensor(tokens):
            if tokens.device.type != "cpu":
                tokens = tokens.cpu()
            # Strip the batch dimension before decoding.
            if len(tokens.shape) > 1:
                tokens = tokens[0]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        text = self.text_normalizer.normalize(text)
        text = get_true_case(text)

        return {"text": text}