"""Custom inference handler for HuggingFace Inference Endpoints.""" |
|
|
|
|
|
from typing import Any, Dict, List, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
try: |
|
|
|
|
|
from .asr_modeling import ASRModel |
|
|
from .asr_pipeline import ASRPipeline |
|
|
except ImportError: |
|
|
|
|
|
from asr_modeling import ASRModel |
|
|
from asr_pipeline import ASRPipeline |
|
|
|
|
|
|
|
|
class EndpointHandler: |
|
|
def __init__(self, path: str = ""): |
|
|
        import os

        import nltk

        # Make sure the NLTK punkt sentence tokenizer data is available.
        nltk.download("punkt_tab", quiet=True)

        # Reduce CUDA memory fragmentation in long-running endpoints.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Allow TF32 matmul and cuDNN kernels for faster inference on Ampere+ GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        if torch.cuda.is_available():
            # Autotune convolution algorithms; pays off when input shapes repeat.
            torch.backends.cudnn.benchmark = True

        model_kwargs = {
            "dtype": self.dtype,
            "low_cpu_mem_usage": True,
        }
        if torch.cuda.is_available():
            # Prefer FlashAttention 2 when installed, otherwise fall back to
            # PyTorch's scaled dot-product attention.
            model_kwargs["attn_implementation"] = (
                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
            )

        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

        # Optional torch.compile; disable with ENABLE_TORCH_COMPILE=0.
        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
            self.model = torch.compile(self.model, mode=compile_mode)
            self.pipe.model = self.model

        if torch.cuda.is_available():
            self._warmup()

    def _is_flash_attn_available(self):
"""Check if flash attention is available.""" |
|
|
import importlib.util |
|
|
|
|
|
return importlib.util.find_spec("flash_attn") is not None |
|
|
|
|
|
def _warmup(self): |
|
|
"""Warmup to trigger model compilation and allocate GPU memory.""" |
|
|
try: |
|
|
|
|
|
sample_rate = self.pipe.model.config.audio_sample_rate |
|
|
dummy_audio = torch.randn(sample_rate, dtype=torch.float32) |
|
|
|
|
|
|
|
|
with torch.inference_mode(): |
|
|
warmup_tokens = self.pipe.model.config.inference_warmup_tokens |
|
|
_ = self.pipe( |
|
|
{"raw": dummy_audio, "sampling_rate": sample_rate}, |
|
|
max_new_tokens=warmup_tokens, |
|
|
) |
|
|
|
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
except Exception as e: |
|
|
print(f"Warmup skipped due to: {e}") |
|
|
|
|
|
def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]: |
|
|
inputs = data.get("inputs") |
|
|
if inputs is None: |
|
|
raise ValueError("Missing 'inputs' in request data") |
|
|
|
|
|
params = data.get("parameters", {}) |
|
|
max_new_tokens = params.get("max_new_tokens", 128) |
|
|
num_beams = params.get("num_beams", 1) |
|
|
do_sample = params.get("do_sample", False) |
|
|
length_penalty = params.get("length_penalty", 1.0) |
|
|
repetition_penalty = params.get("repetition_penalty", 1.05) |
|
|
no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0) |
|
|
early_stopping = params.get("early_stopping", True) |
|
|
default_diversity = self.pipe.model.config.inference_diversity_penalty |
|
|
diversity_penalty = params.get("diversity_penalty", default_diversity) |
|
|
|
|
|
return self.pipe( |
|
|
inputs, |
|
|
max_new_tokens=max_new_tokens, |
|
|
num_beams=num_beams, |
|
|
do_sample=do_sample, |
|
|
length_penalty=length_penalty, |
|
|
repetition_penalty=repetition_penalty, |
|
|
no_repeat_ngram_size=no_repeat_ngram_size, |
|
|
early_stopping=early_stopping, |
|
|
diversity_penalty=diversity_penalty, |
|
|
) |
|
|
|
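

# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch for exercising the handler outside of
# Inference Endpoints. Assumptions (not part of the original handler): the
# model weights live in the current directory (path="."), and the pipeline
# accepts the same {"raw": ..., "sampling_rate": ...} dict the warmup uses.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")

    # One second of silence at the model's sample rate, so no audio file is needed.
    sr = handler.pipe.model.config.audio_sample_rate
    payload = {
        "inputs": {"raw": torch.zeros(sr, dtype=torch.float32), "sampling_rate": sr},
        "parameters": {"max_new_tokens": 32},
    }
    print(handler(payload))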