|
|
| import torch |
| import torchaudio |
| import torch.nn as nn |
|
|
| import modules.p1cupe.model_utils as model_utils |
|
|
|
|
|
|
class FinetuneXLSR(nn.Module):
    """Pretrained XLS-R (300M) encoder followed by a per-frame classifier head.

    The encoder comes from ``torchaudio.pipelines.WAV2VEC2_XLSR_300M``. The
    head is a two-layer MLP applied to the last feature tensor returned by the
    encoder's ``extract_features``.
    """

    def __init__(self, hp, input_wav_length, freeze_feature_encoder=False):
        """Build the encoder, the classifier head and the frame bookkeeping.

        Args:
            hp: Hyper-parameter object. ``hp.noise_level`` and
                ``hp.phoneme_classes`` are required; a truthy optional
                ``hp.xlrs_reset`` re-initialises the pretrained XLSR weights.
            input_wav_length: Input length handed to ``ModelUtils`` to derive
                ``frames_per_window`` (presumably in samples -- confirm
                against ``model_utils``).
            freeze_feature_encoder: When true, the parameters of the XLSR
                feature extractor are excluded from gradient updates.
        """
        super().__init__()

        self.hp = hp
        self.noise_level = hp.noise_level
        # Width of the features fed to the classifier; must match what XLSR emits.
        self.xls_dim = 1024
        # One extra output on top of the phoneme classes
        # (presumably a blank/padding class -- TODO confirm).
        self.output_dim = hp.phoneme_classes + 1

        self.XLSR = torchaudio.pipelines.WAV2VEC2_XLSR_300M.get_model()

        def _reinitialise(submodule):
            # Only modules that define reset_parameters() can be re-initialised.
            if hasattr(submodule, 'reset_parameters'):
                submodule.reset_parameters()

        # Optional switch: discard the pretrained weights and start from scratch.
        if getattr(self.hp, 'xlrs_reset', False):
            print("reset_parameters for XLSR")
            self.XLSR.apply(_reinitialise)

        self.freeze_feature_encoder = freeze_feature_encoder
        if self.freeze_feature_encoder:
            # The feature extractor is reached through the bundle model's
            # ``.model`` attribute; turn gradients off for all its parameters.
            self.XLSR.model.feature_extractor.requires_grad_(False)

        # Layers are created in this order so parameter initialisation
        # consumes the RNG in a fixed sequence.
        head_layers = [
            nn.Linear(self.xls_dim, self.xls_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(self.xls_dim, self.output_dim),
        ]
        self.classifier = nn.Sequential(*head_layers)

        # Frame bookkeeping is delegated to the project's ModelUtils; it is
        # computed last so extract_layer_dims sees the complete module tree.
        self.layer_dims = model_utils.ModelUtils.extract_layer_dims(self)
        wav_length = torch.tensor([input_wav_length])
        frame_count = model_utils.ModelUtils.calculate_layer_sizes(
            self.layer_dims, wav_length, -1
        )[0].int()
        self.frames_per_window = torch.ceil(frame_count - 1).int()
        self.model_utils = model_utils.ModelUtils(
            self.layer_dims, input_wav_length, self.frames_per_window
        )

    def update_frames_per_window(self, input_wav_length):
        """Recompute, store and return ``frames_per_window`` for a new input length.

        Args:
            input_wav_length: New input length, forwarded to
                ``calculate_layer_sizes`` together with ``self.layer_dims``.

        Returns:
            The updated ``self.frames_per_window`` (an int tensor).
        """
        wav_length = torch.tensor([input_wav_length])
        frame_count = self.model_utils.calculate_layer_sizes(
            self.layer_dims, wav_length, -1
        )[0].int()
        self.frames_per_window = torch.ceil(frame_count - 1).int()
        print("frames_per_window (frames per clip if disable_windowing):", self.frames_per_window.item())
        return self.frames_per_window

    def forward(self, x):
        """Return classifier logits for the input waveform tensor.

        Args:
            x: Waveform tensor passed straight to ``XLSR.extract_features``
                (assumed ``(batch, samples)`` -- TODO confirm with callers).

        Returns:
            Output of the classifier head applied to the last feature tensor
            produced by the encoder.
        """
        # Gaussian input noise is a training-time augmentation only.
        add_noise = self.training and self.noise_level > 0
        if add_noise:
            x = x + torch.randn_like(x) * self.noise_level

        # extract_features returns a sequence of feature tensors plus a second
        # value that is unused here; only the last feature tensor is classified.
        layer_outputs, _ = self.XLSR.extract_features(x)
        return self.classifier(layer_outputs[-1])
|
|
|
|
|
|
|
|
def test_model(
    sample_path="tmp/data/audio_samples/9860_8338_000010.flac.wav",
    phoneme_classes=66,
):
    """Smoke-test ``FinetuneXLSR`` on one audio file and print the logits.

    Loads ``sample_path``, resamples it to 16 kHz, builds the model for that
    clip length and runs a single forward pass in eval/inference mode.

    Args:
        sample_path: Audio file to load with ``torchaudio.load``.
        phoneme_classes: Value used for ``hp.phoneme_classes``. The default is
            an arbitrary placeholder for this smoke test (the classifier head
            is untrained here) -- set it to the project's real class count
            when comparing against a trained checkpoint.
    """
    # Local import keeps the file's top-level import block untouched.
    from types import SimpleNamespace

    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(torch.__version__)
    print(torchaudio.__version__)
    torch.random.manual_seed(0)
    print(device)

    # Load the audio first: the model constructor needs the input length.
    sample_waveform, sample_samplerate = torchaudio.load(sample_path)
    sample_waveform = sample_waveform.to(device)
    waveform = torchaudio.functional.resample(sample_waveform, sample_samplerate, 16000)

    # FinetuneXLSR takes a hyper-parameter object and the input length, not a
    # bare ``noise_level`` keyword; supply the attributes its __init__ reads.
    hp = SimpleNamespace(noise_level=0.01, phoneme_classes=phoneme_classes)
    model = FinetuneXLSR(
        hp,
        input_wav_length=waveform.shape[-1],
        freeze_feature_encoder=False,
    ).to(device)
    # A freshly built nn.Module is in training mode; switch to eval so input
    # noise and dropout do not perturb the inference output.
    model.eval()

    print(model.__class__)

    batch = waveform.to(device)
    print(waveform.shape)

    with torch.inference_mode():
        logits = model(batch)

    print(len(logits), logits[0].shape)
    for x, element in enumerate(logits):
        print(x, element)
|
|
def main():
    """Script entry point: run the model smoke test."""
    test_model()


# Only run the smoke test when this file is executed directly, not on import.
if __name__ == "__main__":
    main()