"""Custom inference handler for HuggingFace Inference Endpoints."""
from typing import Any, Dict, List, Union
import torch
try:
    # For remote execution, imports are relative
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    # For local execution, imports are not relative
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from asr_pipeline import ASRPipeline  # type: ignore[no-redef]


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Load the model, build the ASR pipeline, and apply GPU-specific optimizations."""
        import os

        import nltk

        nltk.download("punkt_tab", quiet=True)

        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Enable TF32 for faster matmul on Ampere+ GPUs (A100, etc.)
        # Also beneficial for T4 (Turing), which supports TensorFloat-32
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # Set device and dtype
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use float16 for better T4 compatibility (bfloat16 is not well supported on T4);
        # T4 has excellent float16 performance with tensor cores
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Enable CUDA optimizations
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True

        # Prepare model kwargs for the pipeline
        model_kwargs = {
            "dtype": self.dtype,
            "low_cpu_mem_usage": True,
        }
        if torch.cuda.is_available():
            model_kwargs["attn_implementation"] = (
                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
            )

        # Load the model (this also loads the tokenizer and feature extractor)
        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        # Instantiate the custom pipeline; it takes the feature extractor and tokenizer from the model
        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

        # Apply torch.compile if enabled (after the model has been loaded by the pipeline).
        # Use "default" mode for T4: better compatibility than "reduce-overhead",
        # which is faster on A100+ but can be slower on older GPUs.
        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
            self.model = torch.compile(self.model, mode=compile_mode)
            self.pipe.model = self.model

        # Warm up the model to trigger compilation and optimize kernels
        if torch.cuda.is_available():
            self._warmup()

    def _is_flash_attn_available(self):
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None

    def _warmup(self):
        """Warmup to trigger model compilation and allocate GPU memory."""
        try:
            # Create dummy audio (1 second at the config sample rate)
            sample_rate = self.pipe.model.config.audio_sample_rate
            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)

            # Run inference to trigger torch.compile and kernel optimization
            with torch.inference_mode():
                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
                _ = self.pipe(
                    {"raw": dummy_audio, "sampling_rate": sample_rate},
                    max_new_tokens=warmup_tokens,
                )

            # Force CUDA synchronization to ensure kernels are compiled
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            # Clear the cache after warmup to free memory
            torch.cuda.empty_cache()
        except Exception as e:
            print(f"Warmup skipped due to: {e}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Handle an inference request.

        Expects a payload of the form
        {"inputs": <audio>, "parameters": {"max_new_tokens": 128, ...}},
        where every generation parameter is optional and falls back to the
        defaults below.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        params = data.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", 128)
        num_beams = params.get("num_beams", 1)
        do_sample = params.get("do_sample", False)
        length_penalty = params.get("length_penalty", 1.0)
        repetition_penalty = params.get("repetition_penalty", 1.05)
        no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0)
        early_stopping = params.get("early_stopping", True)
        default_diversity = self.pipe.model.config.inference_diversity_penalty
        diversity_penalty = params.get("diversity_penalty", default_diversity)

        return self.pipe(
            inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            diversity_penalty=diversity_penalty,
        )
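

# ---------------------------------------------------------------------------
# Illustrative local smoke test: a sketch, not part of the Inference Endpoints
# contract. It assumes the model repository files sit in the current directory
# and that the pipeline accepts the same {"raw": ..., "sampling_rate": ...}
# input dict used by _warmup() above. The 16_000 Hz rate and the 1-second
# random clip are placeholder values, not taken from the real model config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    sample_rate = 16_000  # placeholder; the served model reads its rate from config
    request = {
        "inputs": {
            "raw": torch.randn(sample_rate, dtype=torch.float32),
            "sampling_rate": sample_rate,
        },
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(request))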