"""Custom inference handler for HuggingFace Inference Endpoints."""

from typing import Any, Dict, List, Union

import torch

try:
    # For remote execution, imports are relative
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    # For local execution, imports are not relative
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from asr_pipeline import ASRPipeline  # type: ignore[no-redef]


class EndpointHandler:
    def __init__(self, path: str = ""):
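        """Load the ASR model and construct the inference pipeline.

        Args:
            path: Path to the model artifacts; Inference Endpoints passes the
                local repository directory here.
        """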
        import os

        import nltk

        nltk.download("punkt_tab", quiet=True)

        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Enable TF32 for faster matmul on Ampere+ GPUs (A100, etc.)
        # TF32 requires Ampere or newer; on pre-Ampere GPUs such as the T4
        # (Turing) these flags have no effect
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # Set device and dtype
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Use float16 for better T4 compatibility (bfloat16 not well supported on T4)
        # T4 has excellent float16 performance with tensor cores
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # Enable CUDA optimizations
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True

        # Prepare model kwargs for pipeline
        model_kwargs = {
            "dtype": self.dtype,
            "low_cpu_mem_usage": True,
        }
        if torch.cuda.is_available():
            model_kwargs["attn_implementation"] = (
                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
            )

        # Load model (this loads the model, tokenizer, and feature extractor)
        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        # Instantiate the custom pipeline, reusing the model's feature extractor and tokenizer
        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

        # Apply torch.compile if enabled (after model is loaded by pipeline)
        # Use "default" mode for T4 - better compatibility than "reduce-overhead"
        # "reduce-overhead" is better for A100+ but can be slower on older GPUs
        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
            self.model = torch.compile(self.model, mode=compile_mode)
            self.pipe.model = self.model

        # Warm up the model to trigger compilation and optimize kernels
        if torch.cuda.is_available():
            self._warmup()

    def _is_flash_attn_available(self):
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None

    def _warmup(self):
        """Warmup to trigger model compilation and allocate GPU memory."""
        try:
            # Create dummy audio (1 second at config sample rate)
            sample_rate = self.pipe.model.config.audio_sample_rate
            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)

            # Run inference to trigger torch.compile and kernel optimization
            with torch.inference_mode():
                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
                _ = self.pipe(
                    {"raw": dummy_audio, "sampling_rate": sample_rate},
                    max_new_tokens=warmup_tokens,
                )

            # Force CUDA synchronization to ensure kernels are compiled
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                # Clear cache after warmup to free memory
                torch.cuda.empty_cache()

        except Exception as e:
            print(f"Warmup skipped due to: {e}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
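        """Run transcription on a single request.

        Expects ``data`` of the form ``{"inputs": <audio>, "parameters": {...}}``;
        all generation parameters are optional and fall back to the defaults below.
        """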
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        params = data.get("parameters", {})
        max_new_tokens = params.get("max_new_tokens", 128)
        num_beams = params.get("num_beams", 1)
        do_sample = params.get("do_sample", False)
        length_penalty = params.get("length_penalty", 1.0)
        repetition_penalty = params.get("repetition_penalty", 1.05)
        no_repeat_ngram_size = params.get("no_repeat_ngram_size", 0)
        early_stopping = params.get("early_stopping", True)
        default_diversity = self.pipe.model.config.inference_diversity_penalty
        diversity_penalty = params.get("diversity_penalty", default_diversity)

        return self.pipe(
            inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=early_stopping,
            diversity_penalty=diversity_penalty,
        )
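

# Minimal local smoke test: a sketch only, not part of the Inference Endpoints
# contract. The "./model" path is a placeholder for a local model directory;
# the payload mirrors the dict format used in _warmup above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")

    # One second of random noise at the model's configured sample rate.
    sample_rate = handler.pipe.model.config.audio_sample_rate
    dummy_audio = torch.randn(sample_rate, dtype=torch.float32)

    result = handler(
        {
            "inputs": {"raw": dummy_audio, "sampling_rate": sample_rate},
            "parameters": {"max_new_tokens": 64},
        }
    )
    print(result)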