| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
| from whisper.model import AudioEncoder, sinusoids, Whisper, ModelDimensions |
|
|
|
|
class AudioEncoder_(AudioEncoder):
    """Whisper AudioEncoder extended with intermediate-layer feature extraction.

    Adds `extract_feature`, which runs only the first ``target_layer``
    transformer blocks and also tolerates inputs longer than the pretrained
    positional-embedding table by regrowing the table on demand.
    """

    def __init__(self, *args, **kwargs):
        # No extra state is introduced; this subclass only adds a method.
        super().__init__(*args, **kwargs)

    def extract_feature(self, x: Tensor, target_layer: Optional[int] = None) -> Tensor:
        """Return hidden states after ``target_layer`` transformer blocks.

        Parameters
        ----------
        x : Tensor, shape = (batch_size, n_mels, n_ctx)
            The mel spectrogram of the audio.
        target_layer : int, optional
            Number of transformer blocks to apply. ``None`` (default) runs
            all blocks, i.e. behaves like the standard encoder forward pass
            up to (but not including) the final layer norm.

        Returns
        -------
        Tensor
            Activations after the selected blocks, shape
            (batch_size, n_ctx_out, n_audio_state).
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        length_x = x.shape[1]
        if length_x > self.positional_embedding.shape[0]:
            # Input outlives the registered table: regrow it once, on the
            # same device as the current buffer.
            # NOTE(review): re-registering replaces the persistent buffer, so
            # a subsequently saved state_dict contains the enlarged table.
            self.register_buffer(
                "positional_embedding",
                sinusoids(length_x, self.positional_embedding.shape[1]).to(
                    self.positional_embedding.device
                ),
            )
        # Bug fix: the original unconditionally reassigned the buffer
        # (`self.positional_embedding = ...to(x.device)`) on every call,
        # mutating module state even when the device already matched.
        # Only move (and cache) the buffer when it is actually elsewhere.
        if self.positional_embedding.device != x.device:
            self.positional_embedding = self.positional_embedding.to(x.device)
        # Add in the embedding's dtype first, then cast back — this matches
        # the original operation order, so numerics are unchanged.
        x = (x + self.positional_embedding[:length_x, :]).to(x.dtype)

        if target_layer is None:
            target_layer = len(self.blocks)

        for block in self.blocks[:target_layer]:
            x = block(x)

        return x
|
|
|
|
class Whisper_(Whisper):
    """Whisper model whose encoder supports intermediate-feature extraction.

    Replaces the stock encoder with :class:`AudioEncoder_`, built from the
    same dimensions, and exposes a thin convenience wrapper around its
    ``extract_feature`` method.
    """

    def __init__(self, dims: ModelDimensions):
        super().__init__(dims)
        d = self.dims
        # Swap in the feature-extraction-capable encoder, mirroring the
        # constructor arguments the base class uses for its own encoder.
        self.encoder = AudioEncoder_(
            d.n_mels,
            d.n_audio_ctx,
            d.n_audio_state,
            d.n_audio_head,
            d.n_audio_layer,
        )

    def extract_features(self, mel: torch.Tensor, target_layer: Optional[int] = None):
        """Delegate to the encoder: hidden states of ``mel`` after ``target_layer`` blocks."""
        return self.encoder.extract_feature(mel, target_layer)
|
|