| |
| |
| |
| |
| |
| |
| |
|
|
| import io |
| import logging |
| import os |
| from typing import Optional, Union |
|
|
| import soundfile as sf |
| import torch |
| from whisper import _MODELS, _download, _ALIGNMENT_HEADS, available_models |
| from whisper.audio import log_mel_spectrogram |
| from whisper.model import ModelDimensions |
|
|
| from whisper_model import Whisper_ |
|
|
| logger = logging.getLogger("dump_feature") |
|
|
|
|
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: Optional[str] = None,
    in_memory: bool = False,
) -> Whisper_:
    """
    Load a Whisper checkpoint as a `Whisper_` model for feature extraction.

    Reference: https://github.com/openai/whisper/blob/main/whisper/__init__.py#L97
    But we will load a `Whisper_` model for feature extraction.

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to CUDA when available.
    download_root : Optional[str]
        path to download the model files; by default, it uses "~/.cache/whisper"
        (or `$XDG_CACHE_HOME/whisper` when that environment variable is set).
    in_memory : bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper_
        The Whisper model instance wrapped for feature extraction.

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor an existing file path.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        if in_memory:
            # Read through a context manager so the handle is always closed
            # (the previous `open(name, "rb").read()` leaked the file object).
            with open(name, "rb") as f:
                checkpoint_file = f.read()
        else:
            checkpoint_file = name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    # Drop the (possibly large) raw bytes before building the model to keep
    # peak host memory down.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper_(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
|
|
|
|
class WhisperFeatureReader(object):
    """Reads 16 kHz audio files and extracts hidden states from a chosen
    layer of a Whisper encoder (the decoder is discarded to save memory)."""

    def __init__(self, root, ckpt, layer, device):
        self.device = device
        logger.info(f"device = {self.device}")

        self.model: Whisper_ = load_model(name=ckpt, device=self.device, download_root=root).eval()
        # Only the encoder is needed for feature extraction; free the decoder.
        self.model.decoder = None
        self.layer = layer

    def read_audio(self, path, ref_len=None):
        """Read an audio file and return its samples as a numpy array.

        Raises ValueError when the file is not sampled at 16 kHz (was an
        `assert`, which is silently stripped under `python -O`).
        Logs a warning when the read length deviates from `ref_len` by more
        than 160 samples (10 ms at 16 kHz).
        """
        wav, sample_rate = sf.read(path)
        if sample_rate != 16000:
            raise ValueError(f"expected 16000 Hz audio, got {sample_rate} ({path})")
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logger.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        """Return hidden states of `self.layer` for the audio at `path`,
        truncated to the expected number of frames (one per 320 samples —
        presumably Whisper's 160-sample hop times the encoder's 2x conv
        downsampling; TODO confirm against the model config)."""
        wav = self.read_audio(path, ref_len)
        audio_length = len(wav)
        with torch.no_grad():
            mel = log_mel_spectrogram(torch.from_numpy(wav).float().to(self.device))
            hidden = self.model.extract_features(mel.unsqueeze(0), target_layer=self.layer)
            # log_mel_spectrogram pads to 30 s; cut back to the real length.
            feature_length = audio_length // 320
            hidden = hidden[0, :feature_length]
        return hidden.contiguous()
|
|