Spaces: Build error
cantabile-kwok committed · Commit 05005db · 1 Parent: 8bd60fe
prepare demo page

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- app.py +51 -0
- pretrained/WavLM-Large.pt +3 -0
- pretrained/config.yml +201 -0
- pretrained/generator.ckpt +3 -0
- pretrained/vq-wav2vec_kmeans.pt +3 -0
- requirements.txt +25 -0
- vec2wav2/__init__.py +3 -0
- vec2wav2/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/__pycache__/__init__.cpython-311.pyc +0 -0
- vec2wav2/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/bin/.DS_Store +0 -0
- vec2wav2/bin/__init__.py +0 -0
- vec2wav2/bin/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/bin/__pycache__/vc.cpython-310.pyc +0 -0
- vec2wav2/bin/decode.py +163 -0
- vec2wav2/bin/gradio_app.py +51 -0
- vec2wav2/bin/train.py +1007 -0
- vec2wav2/bin/vc.py +128 -0
- vec2wav2/datasets/__init__.py +1 -0
- vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc +0 -0
- vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc +0 -0
- vec2wav2/datasets/scp_dataset.py +300 -0
- vec2wav2/distributed/__init__.py +0 -0
- vec2wav2/distributed/launch.py +163 -0
- vec2wav2/layers/__init__.py +6 -0
- vec2wav2/layers/__pycache__/__init__.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/__init__.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/activations.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc +0 -0
- vec2wav2/layers/__pycache__/upsample.cpython-310.pyc +0 -0
- vec2wav2/layers/__pycache__/upsample.cpython-39.pyc +0 -0
- vec2wav2/layers/activations.py +197 -0
- vec2wav2/layers/causal_conv.py +66 -0
- vec2wav2/layers/pqmf.py +150 -0
- vec2wav2/layers/residual_block.py +222 -0
- vec2wav2/layers/residual_stack.py +85 -0
- vec2wav2/layers/tade_res_block.py +160 -0
- vec2wav2/layers/upsample.py +194 -0
- vec2wav2/losses/__init__.py +4 -0
app.py
ADDED
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import gradio as gr
+import logging
+import yaml
+import soundfile as sf
+import os
+from pathlib import Path
+from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args
+
+# Create Gradio interface
+def create_interface():
+    args = vc_args()
+    logger = configure_logging(args.verbose)
+    voice_converter = VoiceConverter(
+        expdir=args.expdir,
+        token_extractor=args.token_extractor,
+        prompt_extractor=args.prompt_extractor,
+        prompt_output_layer=args.prompt_output_layer,
+        checkpoint=args.checkpoint,
+        script_logger=logger
+    )
+    with gr.Blocks(title="Voice Conversion") as demo:
+        gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
+        gr.Markdown("Upload source audio and target speaker audio to convert the voice.")
+
+        with gr.Row():
+            source_audio = gr.Audio(label="Source Audio", type="filepath")
+            target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")
+
+        examples = [
+            ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
+            ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
+        ]
+        gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])
+
+        convert_btn = gr.Button("Convert Voice")
+        output_audio = gr.Audio(label="Converted Audio")
+
+        convert_btn.click(
+            fn=voice_converter.voice_conversion,
+            inputs=[source_audio, target_audio],
+            outputs=output_audio
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    demo = create_interface()
+    demo.launch(share=True)
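
The same interface builder is reused by vec2wav2/bin/gradio_app.py below. As a minimal local-launch sketch (assuming vc_args() can resolve its checkpoint arguments from the files shipped under pretrained/, which this commit does not show; the server options are illustrative):

    # Hypothetical local launch of the demo, not part of the commit itself.
    from app import create_interface

    if __name__ == "__main__":
        demo = create_interface()
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
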
pretrained/WavLM-Large.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
+size 1261965425
pretrained/config.yml
ADDED
@@ -0,0 +1,201 @@
+allow_cache: false
+batch_frames: 3600
+config: conf/ctxv2w.v1.yaml
+crop_max_frames: 100
+discriminator_adv_loss_params:
+  average_by_discriminators: false
+discriminator_grad_norm: -1
+discriminator_optimizer_params:
+  betas:
+  - 0.5
+  - 0.9
+  lr: 0.0002
+  weight_decay: 0.0
+discriminator_optimizer_type: Adam
+discriminator_params:
+  follow_official_norm: true
+  period_discriminator_params:
+    bias: true
+    channels: 32
+    downsample_scales:
+    - 3
+    - 3
+    - 3
+    - 3
+    - 1
+    in_channels: 1
+    kernel_sizes:
+    - 5
+    - 3
+    max_downsample_channels: 1024
+    nonlinear_activation: LeakyReLU
+    nonlinear_activation_params:
+      negative_slope: 0.1
+    out_channels: 1
+    use_spectral_norm: false
+    use_weight_norm: true
+  periods:
+  - 2
+  - 3
+  - 5
+  - 7
+  - 11
+  scale_discriminator_params:
+    bias: true
+    channels: 128
+    downsample_scales:
+    - 4
+    - 4
+    - 4
+    - 4
+    - 1
+    in_channels: 1
+    kernel_sizes:
+    - 15
+    - 41
+    - 5
+    - 3
+    max_downsample_channels: 1024
+    max_groups: 16
+    nonlinear_activation: LeakyReLU
+    nonlinear_activation_params:
+      negative_slope: 0.1
+    out_channels: 1
+  scale_downsample_pooling: AvgPool1d
+  scale_downsample_pooling_params:
+    kernel_size: 4
+    padding: 2
+    stride: 2
+  scales: 3
+discriminator_scheduler_params:
+  gamma: 0.5
+  milestones:
+  - 200000
+  - 400000
+  - 600000
+  - 800000
+discriminator_scheduler_type: MultiStepLR
+discriminator_train_start_steps: 0
+discriminator_type: HiFiGANMultiScaleMultiPeriodDiscriminator
+distributed: true
+dropout_features: 0.0
+eval_interval_steps: 100000
+feat_match_loss_params:
+  average_by_discriminators: false
+  average_by_layers: false
+  include_final_outputs: false
+frontend_mel_prediction_stop_steps: 200000
+frontend_params:
+  conformer_params:
+    activation_type: swish
+    attention_dim: 184
+    attention_dropout_rate: 0.2
+    attention_heads: 2
+    cnn_module_kernel: 31
+    concat_after: false
+    dropout_rate: 0.2
+    linear_units: 1536
+    macaron_style: true
+    normalize_before: true
+    num_blocks: 2
+    pos_enc_layer_type: rel_pos
+    positional_dropout_rate: 0.2
+    positionwise_conv_kernel_size: 3
+    positionwise_layer_type: conv1d
+    selfattention_layer_type: rel_selfattn
+    use_cnn_module: true
+  prompt_channels: 1024
+  vqvec_channels: 512
+generator_adv_loss_params:
+  average_by_discriminators: false
+generator_grad_norm: -1
+generator_optimizer_params:
+  betas:
+  - 0.5
+  - 0.9
+  lr: 0.0002
+  weight_decay: 0.0
+generator_optimizer_type: Adam
+generator_params:
+  bias: true
+  channels: 512
+  condition_dim: 1024
+  in_channels: 184
+  kernel_size: 7
+  nonlinear_activation: snakebeta-condition
+  out_channels: 1
+  resblock: '1'
+  resblock_dilations:
+  - - 1
+    - 3
+    - 5
+  - - 1
+    - 3
+    - 5
+  - - 1
+    - 3
+    - 5
+  resblock_kernel_sizes:
+  - 3
+  - 7
+  - 11
+  snake_logscale: true
+  upsample_kernel_sizes:
+  - 16
+  - 10
+  - 6
+  - 4
+  upsample_scales:
+  - 8
+  - 5
+  - 3
+  - 2
+  use_additional_convs: true
+  use_weight_norm: true
+generator_scheduler_params:
+  gamma: 0.5
+  milestones:
+  - 200000
+  - 400000
+  - 600000
+  - 800000
+generator_scheduler_type: MultiStepLR
+generator_train_start_steps: 1
+generator_type: BigVGAN
+hop_size: 240
+lambda_adv: 1.0
+lambda_aux: 45.0
+lambda_feat_match: 2.0
+lambda_frontend_mel_prediction: 60
+log_interval_steps: 1000
+max_num_frames: 3000
+mel_loss_params:
+  fft_size: 2048
+  fmax: 8000
+  fmin: 40
+  fs: 24000
+  hop_size: 300
+  log_base: null
+  num_mels: 80
+  win_length: 1200
+  window: hann
+min_num_frames: 600
+num_mels: 80
+num_save_intermediate_results: 4
+num_workers: 8
+outdir: exp/train_all_ctxv2w.v1
+pin_memory: true
+pretrain: ''
+prompt_fold_by_2: true
+prompt_net_type: ConvPromptPrenet
+rank: 0
+sampling_rate: 24000
+save_interval_steps: 10000
+use_feat_match_loss: true
+use_mel_loss: true
+use_stft_loss: false
+verbose: 1
+version: 0.5.3
+vq_codebook: feats/vqidx/codebook.npy
+win_length: 697
+world_size: 4
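
The configuration above is consumed as a plain dictionary by the scripts in this commit (decode.py reads it with yaml.load and then indexes keys such as sampling_rate, hop_size, and vq_codebook). A short inspection sketch, assuming the file sits at pretrained/config.yml as added here:

    # Load the shipped config and print a few generation-related keys.
    import yaml

    with open("pretrained/config.yml") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # sampling_rate=24000 and hop_size=240 imply 100 feature frames per second of audio.
    print(config["sampling_rate"], config["hop_size"], config["generator_type"])
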
pretrained/generator.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a10b9df62462bbf48382970ffba267b458b00b361bcb245701e3d3c0b6bd19f
+size 161604549
pretrained/vq-wav2vec_kmeans.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c975a93479dc5f3cfc4339032e1547c6034eddd15eb1cba73364c20786b42a5a
+size 336509919
requirements.txt
ADDED
@@ -0,0 +1,25 @@
+torchaudio==0.13.1
+auraloss==0.4.0
+cython==3.0.10
+einops
+debugpy==1.8.0
+fairseq==0.12.2
+filelock~=3.12.2
+h5py
+kaldiio~=2.18.0
+librosa==0.8.1
+matplotlib~=3.4.3
+nltk==3.8.1
+numpy
+pathlib~=1.0.1
+pyyaml~=6.0
+scikit-learn
+scipy~=1.7.1
+setuptools==65.6.3
+six==1.16.0
+soundfile~=0.10.3.post1
+sox
+tensorboard
+tensorboardx~=2.5.1
+tqdm~=4.62.3
+transformers==4.42.3
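
A quick environment check against a few of the pins above can be done with the standard library; this is an illustrative helper, not part of the commit:

    # Report installed versions of selected pinned requirements.
    import importlib.metadata as md

    for pkg in ("torchaudio", "fairseq", "librosa", "transformers", "kaldiio"):
        try:
            print(f"{pkg}=={md.version(pkg)}")
        except md.PackageNotFoundError:
            print(f"{pkg} is not installed")
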
vec2wav2/__init__.py
ADDED
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+__version__ = ""
vec2wav2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (211 Bytes).

vec2wav2/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (289 Bytes).

vec2wav2/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (225 Bytes).

vec2wav2/bin/.DS_Store
ADDED
Binary file (6.15 kB).

vec2wav2/bin/__init__.py
ADDED
File without changes.

vec2wav2/bin/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (199 Bytes).

vec2wav2/bin/__pycache__/vc.cpython-310.pyc
ADDED
Binary file (4.76 kB).
vec2wav2/bin/decode.py
ADDED
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+# Modified by Yiwei Guo, 2024
+
+"""Decode with trained vec2wav Generator."""
+
+import argparse
+import logging
+import os
+import time
+
+import numpy as np
+import soundfile as sf
+import torch
+import yaml
+
+from tqdm import tqdm
+
+from vec2wav2.datasets import MelSCPDataset
+from vec2wav2.utils import load_model, load_feat_codebook, idx2vec
+
+
+def set_loglevel(verbose):
+    # set logger
+    if verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    elif verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+        logging.warning("Skip DEBUG/INFO messages")
+
+
+def main():
+    """Run decoding process."""
+    parser = argparse.ArgumentParser(
+        description="Decode from audio tokens and acoustic prompts with trained vec2wav model"
+        "(See detail in vec2wav2/bin/decode.py)."
+    )
+    parser.add_argument(
+        "--feats-scp",
+        "--scp",
+        default=None,
+        type=str,
+        required=True,
+        help="kaldi-style feats.scp file. "
+    )
+    parser.add_argument(
+        "--prompt-scp",
+        default=None,
+        type=str,
+        help="kaldi-style prompt.scp file. Similar to feats.scp."
+    )
+    parser.add_argument(
+        "--outdir",
+        type=str,
+        required=True,
+        help="directory to save generated speech.",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="checkpoint file to be loaded.",
+    )
+    parser.add_argument(
+        "--config",
+        default=None,
+        type=str,
+        help="yaml format configuration file. if not explicitly provided, "
+        "it will be searched in the checkpoint directory. (default=None)",
+    )
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)",
+    )
+    args = parser.parse_args()
+    set_loglevel(args.verbose)
+
+    # check directory existence
+    if not os.path.exists(args.outdir):
+        os.makedirs(args.outdir)
+
+    # load config
+    if args.config is None:
+        dirname = os.path.dirname(args.checkpoint)
+        args.config = os.path.join(dirname, "config.yml")
+    with open(args.config) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+    config.update(vars(args))
+
+    # get dataset
+    dataset = MelSCPDataset(
+        vqidx_scp=args.feats_scp,
+        prompt_scp=args.prompt_scp,
+        return_utt_id=True,
+    )
+    logging.info(f"The number of features to be decoded = {len(dataset)}.")
+
+    # setup model
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logging.info(f"Using {'GPU' if torch.cuda.is_available() else 'CPU'}.")
+
+    model = load_model(args.checkpoint, config)
+    logging.info(f"Loaded model parameters from {args.checkpoint}.")
+
+    model.backend.remove_weight_norm()
+    model = model.eval().to(device)
+
+    # load vq codebook
+    feat_codebook, feat_codebook_numgroups = load_feat_codebook(np.load(config["vq_codebook"], allow_pickle=True), device)
+
+    # start generation
+    total_rtf = 0.0
+    with torch.no_grad(), tqdm(dataset, desc="[decode]") as pbar:
+        for idx, batch in enumerate(pbar, 1):
+            utt_id, vqidx, prompt = batch[0], batch[1], batch[2]
+
+            vqidx = torch.tensor(vqidx).to(device)  # (L, G)
+            prompt = torch.tensor(prompt).unsqueeze(0).to(device)  # (1, L', D')
+
+            vqidx = vqidx.long()
+            vqvec = idx2vec(feat_codebook, vqidx, feat_codebook_numgroups).unsqueeze(0)  # (1, L, D)
+
+            # generate
+            start = time.time()
+            y = model.inference(vqvec, prompt)[-1].view(-1)
+            rtf = (time.time() - start) / (len(y) / config["sampling_rate"])
+            pbar.set_postfix({"RTF": rtf})
+            total_rtf += rtf
+
+            tgt_dir = os.path.dirname(os.path.join(config["outdir"], f"{utt_id}.wav"))
+            os.makedirs(tgt_dir, exist_ok=True)
+            basename = os.path.basename(f"{utt_id}.wav")
+            # save as PCM 16 bit wav file
+            sf.write(
+                os.path.join(tgt_dir, basename),
+                y.cpu().numpy(),
+                config["sampling_rate"],
+                "PCM_16",
+            )
+
+    # report average RTF
+    logging.info(f"Finished generation of {idx} utterances (RTF = {total_rtf / idx:.03f}).")
+
+
+if __name__ == "__main__":
+    main()
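
decode.py consumes kaldi-style .scp files that map utterance IDs to stored feature matrices. As a rough sketch of that input format, assuming the dataset reads archives with kaldiio (which is pinned in requirements.txt; the actual MelSCPDataset implementation is not shown in this commit):

    # A feats.scp line has the form "<utt-id> <ark-path>:<byte-offset>", e.g.
    #   utt001 feats/vqidx/feats.1.ark:42
    import kaldiio

    feats = kaldiio.load_scp("feats.scp")  # lazy mapping: utt-id -> numpy array
    vqidx = feats["utt001"]                # (L, G) token indices, as in the decode loop above
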
vec2wav2/bin/gradio_app.py
ADDED
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import gradio as gr
+import logging
+import yaml
+import soundfile as sf
+import os
+from pathlib import Path
+from vec2wav2.bin.vc import VoiceConverter, configure_logging, vc_args
+
+# Create Gradio interface
+def create_interface():
+    args = vc_args()
+    logger = configure_logging(args.verbose)
+    voice_converter = VoiceConverter(
+        expdir=args.expdir,
+        token_extractor=args.token_extractor,
+        prompt_extractor=args.prompt_extractor,
+        prompt_output_layer=args.prompt_output_layer,
+        checkpoint=args.checkpoint,
+        script_logger=logger
+    )
+    with gr.Blocks(title="Voice Conversion") as demo:
+        gr.Markdown("# vec2wav 2.0 Voice Conversion Demo")
+        gr.Markdown("Upload source audio and target speaker audio to convert the voice.")
+
+        with gr.Row():
+            source_audio = gr.Audio(label="Source Audio", type="filepath")
+            target_audio = gr.Audio(label="Target Speaker Audio", type="filepath")
+
+        examples = [
+            ["examples/Zuckerberg.wav", "examples/Rachel.wav"],
+            ["examples/TheresaMay.wav", "examples/OptimusPrime.wav"]
+        ]
+        gr.Examples(examples, label="Examples", inputs=[source_audio, target_audio])
+
+        convert_btn = gr.Button("Convert Voice")
+        output_audio = gr.Audio(label="Converted Audio")
+
+        convert_btn.click(
+            fn=voice_converter.voice_conversion,
+            inputs=[source_audio, target_audio],
+            outputs=output_audio
+        )
+
+    return demo
+
+if __name__ == "__main__":
+    demo = create_interface()
+    demo.launch(share=True)
vec2wav2/bin/train.py
ADDED
|
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# Copyright 2019 Tomoki Hayashi
|
| 5 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 6 |
+
|
| 7 |
+
# Modified by Yiwei Guo, 2024
|
| 8 |
+
|
| 9 |
+
"""Train vec2wav."""
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
|
| 19 |
+
import matplotlib
|
| 20 |
+
import numpy as np
|
| 21 |
+
import soundfile as sf
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import yaml
|
| 25 |
+
import torch.multiprocessing as mp
|
| 26 |
+
from tensorboardX import SummaryWriter
|
| 27 |
+
from torch.utils.data import DataLoader
|
| 28 |
+
from tqdm import tqdm
|
| 29 |
+
|
| 30 |
+
import vec2wav2
|
| 31 |
+
import vec2wav2.models
|
| 32 |
+
import vec2wav2.optimizers
|
| 33 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 34 |
+
|
| 35 |
+
from vec2wav2.datasets import AudioMelSCPDataset
|
| 36 |
+
from vec2wav2.layers import PQMF
|
| 37 |
+
from vec2wav2.losses import DiscriminatorAdversarialLoss
|
| 38 |
+
from vec2wav2.losses import FeatureMatchLoss
|
| 39 |
+
from vec2wav2.losses import GeneratorAdversarialLoss
|
| 40 |
+
from vec2wav2.losses import MelSpectrogramLoss
|
| 41 |
+
from vec2wav2.losses import MultiResolutionSTFTLoss
|
| 42 |
+
from vec2wav2.utils import crop_seq, load_feat_codebook, idx2vec
|
| 43 |
+
|
| 44 |
+
from vec2wav2.utils.espnet_utils import pad_list, make_non_pad_mask
|
| 45 |
+
|
| 46 |
+
# set to avoid matplotlib error in CLI environment
|
| 47 |
+
matplotlib.use("Agg")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def set_loglevel(verbose):
|
| 51 |
+
# set logger
|
| 52 |
+
if verbose > 1:
|
| 53 |
+
logging.basicConfig(
|
| 54 |
+
level=logging.DEBUG,
|
| 55 |
+
stream=sys.stdout,
|
| 56 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
| 57 |
+
)
|
| 58 |
+
elif verbose > 0:
|
| 59 |
+
logging.basicConfig(
|
| 60 |
+
level=logging.INFO,
|
| 61 |
+
stream=sys.stdout,
|
| 62 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
logging.basicConfig(
|
| 66 |
+
level=logging.WARN,
|
| 67 |
+
stream=sys.stdout,
|
| 68 |
+
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
| 69 |
+
)
|
| 70 |
+
logging.warning("Skip DEBUG/INFO messages")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Trainer(object):
|
| 74 |
+
"""Customized trainer module for Parallel WaveGAN training."""
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
steps,
|
| 79 |
+
epochs,
|
| 80 |
+
data_loader,
|
| 81 |
+
sampler,
|
| 82 |
+
model,
|
| 83 |
+
criterion,
|
| 84 |
+
optimizer,
|
| 85 |
+
scheduler,
|
| 86 |
+
config,
|
| 87 |
+
device=torch.device("cpu"),
|
| 88 |
+
):
|
| 89 |
+
"""Initialize trainer.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
steps (int): Initial global steps.
|
| 93 |
+
epochs (int): Initial global epochs.
|
| 94 |
+
data_loader (dict): Dict of data loaders. It must contain "train" and "dev" loaders.
|
| 95 |
+
model (dict): Dict of models. It must contain "generator" and "discriminator" models.
|
| 96 |
+
criterion (dict): Dict of criteria. It must contain "stft" and "mse" criteria.
|
| 97 |
+
optimizer (dict): Dict of optimizers. It must contain "generator" and "discriminator" optimizers.
|
| 98 |
+
scheduler (dict): Dict of schedulers. It must contain "generator" and "discriminator" schedulers.
|
| 99 |
+
config (dict): Config dict loaded from yaml format configuration file.
|
| 100 |
+
device (torch.deive): Pytorch device instance.
|
| 101 |
+
|
| 102 |
+
"""
|
| 103 |
+
self.steps = steps
|
| 104 |
+
self.epochs = epochs
|
| 105 |
+
self.data_loader = data_loader
|
| 106 |
+
self.sampler = sampler
|
| 107 |
+
self.model = model
|
| 108 |
+
self.criterion = criterion
|
| 109 |
+
self.optimizer = optimizer
|
| 110 |
+
self.scheduler = scheduler
|
| 111 |
+
self.config = config
|
| 112 |
+
self.device = device
|
| 113 |
+
self.writer = SummaryWriter(config["outdir"])
|
| 114 |
+
self.finish_train = False
|
| 115 |
+
self.total_train_loss = defaultdict(float)
|
| 116 |
+
self.total_eval_loss = defaultdict(float)
|
| 117 |
+
|
| 118 |
+
# load vq codebook
|
| 119 |
+
feat_codebook_path = self.config["vq_codebook"]
|
| 120 |
+
|
| 121 |
+
self.feat_codebook, self.feat_codebook_numgroups = load_feat_codebook(np.load(feat_codebook_path, allow_pickle=True), device)
|
| 122 |
+
|
| 123 |
+
def run(self):
|
| 124 |
+
"""Run training."""
|
| 125 |
+
self.tqdm = tqdm(initial=self.steps, total=self.config["train_max_steps"], desc="[train]")
|
| 126 |
+
while True:
|
| 127 |
+
# train one epoch
|
| 128 |
+
self._train_epoch()
|
| 129 |
+
|
| 130 |
+
# check whether training is finished
|
| 131 |
+
if self.finish_train:
|
| 132 |
+
break
|
| 133 |
+
|
| 134 |
+
self.tqdm.close()
|
| 135 |
+
logging.info("Finished training.")
|
| 136 |
+
|
| 137 |
+
def save_checkpoint(self, checkpoint_path):
|
| 138 |
+
"""Save checkpoint.
|
| 139 |
+
Args:
|
| 140 |
+
checkpoint_path (str): Checkpoint path to be saved.
|
| 141 |
+
"""
|
| 142 |
+
state_dict = {
|
| 143 |
+
"optimizer": {
|
| 144 |
+
"generator": self.optimizer["generator"].state_dict(),
|
| 145 |
+
"discriminator": self.optimizer["discriminator"].state_dict(),
|
| 146 |
+
},
|
| 147 |
+
"scheduler": {
|
| 148 |
+
"generator": self.scheduler["generator"].state_dict(),
|
| 149 |
+
"discriminator": self.scheduler["discriminator"].state_dict(),
|
| 150 |
+
},
|
| 151 |
+
"steps": self.steps,
|
| 152 |
+
"epochs": self.epochs,
|
| 153 |
+
}
|
| 154 |
+
if self.config["distributed"]:
|
| 155 |
+
state_dict["model"] = {
|
| 156 |
+
"generator": self.model["generator"].module.state_dict(),
|
| 157 |
+
"discriminator": self.model["discriminator"].module.state_dict(),
|
| 158 |
+
}
|
| 159 |
+
else:
|
| 160 |
+
state_dict["model"] = {
|
| 161 |
+
"generator": self.model["generator"].state_dict(),
|
| 162 |
+
"discriminator": self.model["discriminator"].state_dict(),
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
if not os.path.exists(os.path.dirname(checkpoint_path)):
|
| 166 |
+
os.makedirs(os.path.dirname(checkpoint_path))
|
| 167 |
+
torch.save(state_dict, checkpoint_path)
|
| 168 |
+
|
| 169 |
+
def load_checkpoint(self, checkpoint_path, load_only_params=False):
|
| 170 |
+
"""Load checkpoint.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
checkpoint_path (str): Checkpoint path to be loaded.
|
| 174 |
+
load_only_params (bool): Whether to load only model parameters.
|
| 175 |
+
|
| 176 |
+
"""
|
| 177 |
+
state_dict = torch.load(checkpoint_path, map_location="cpu")
|
| 178 |
+
if self.config["distributed"]:
|
| 179 |
+
self.model["generator"].module.load_state_dict(
|
| 180 |
+
state_dict["model"]["generator"]
|
| 181 |
+
)
|
| 182 |
+
self.model["discriminator"].module.load_state_dict(
|
| 183 |
+
state_dict["model"]["discriminator"]
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
self.model["generator"].load_state_dict(state_dict["model"]["generator"])
|
| 187 |
+
self.model["discriminator"].load_state_dict(
|
| 188 |
+
state_dict["model"]["discriminator"]
|
| 189 |
+
)
|
| 190 |
+
if not load_only_params:
|
| 191 |
+
self.steps = state_dict["steps"]
|
| 192 |
+
self.epochs = state_dict["epochs"]
|
| 193 |
+
self.optimizer["generator"].load_state_dict(state_dict["optimizer"]["generator"])
|
| 194 |
+
self.optimizer["discriminator"].load_state_dict(state_dict["optimizer"]["discriminator"])
|
| 195 |
+
self.scheduler["generator"].load_state_dict(state_dict["scheduler"]["generator"])
|
| 196 |
+
self.scheduler["discriminator"].load_state_dict(state_dict["scheduler"]["discriminator"])
|
| 197 |
+
|
| 198 |
+
def _train_step(self, batch):
|
| 199 |
+
"""Train model one step."""
|
| 200 |
+
# parse batch
|
| 201 |
+
vqidx, mel, prompt, y, xlens, prompt_lens = batch
|
| 202 |
+
vqidx = vqidx.to(self.device)
|
| 203 |
+
mel = mel.to(self.device)
|
| 204 |
+
prompt = prompt.to(self.device)
|
| 205 |
+
vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups) # (B, L, D)
|
| 206 |
+
y = y.unsqueeze(-2).to(self.device) # (B, 1, T)
|
| 207 |
+
|
| 208 |
+
# build mask
|
| 209 |
+
mask = make_non_pad_mask(xlens).to(self.device) # (B, L)
|
| 210 |
+
prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt)
|
| 211 |
+
|
| 212 |
+
# crop wav sequence
|
| 213 |
+
crop_xlen = min(self.config["crop_max_frames"], min(xlens))
|
| 214 |
+
x_offsets = [np.random.randint(0, l - crop_xlen + 1) for l in xlens]
|
| 215 |
+
crop_ylen = crop_xlen * self.config["hop_size"]
|
| 216 |
+
y_offsets = [o * self.config["hop_size"] for o in x_offsets]
|
| 217 |
+
y = crop_seq(y, y_offsets, crop_ylen)
|
| 218 |
+
|
| 219 |
+
#######################
|
| 220 |
+
# Generator #
|
| 221 |
+
#######################
|
| 222 |
+
if self.steps > self.config.get("generator_train_start_steps", 0):
|
| 223 |
+
mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T)
|
| 224 |
+
|
| 225 |
+
# initialize
|
| 226 |
+
gen_loss, aux_loss = 0.0, 0.0
|
| 227 |
+
|
| 228 |
+
# frontend mel prediction loss
|
| 229 |
+
if self.steps <= self.config.get("frontend_mel_prediction_stop_steps", 0):
|
| 230 |
+
frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
|
| 231 |
+
torch.masked_select(mel_, mask.unsqueeze(-1)))
|
| 232 |
+
self.total_train_loss["train/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
|
| 233 |
+
gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss
|
| 234 |
+
|
| 235 |
+
# multi-resolution sfft loss
|
| 236 |
+
if self.config["use_stft_loss"]:
|
| 237 |
+
sc_loss, mag_loss = self.criterion["stft"](y_, y)
|
| 238 |
+
aux_loss += sc_loss + mag_loss
|
| 239 |
+
self.total_train_loss["train/spectral_convergence_loss"] += sc_loss.item()
|
| 240 |
+
self.total_train_loss["train/log_stft_magnitude_loss"] += mag_loss.item()
|
| 241 |
+
|
| 242 |
+
# subband multi-resolution stft loss
|
| 243 |
+
if self.config["use_subband_stft_loss"]:
|
| 244 |
+
aux_loss *= 0.5 # for balancing with subband stft loss
|
| 245 |
+
y_mb = self.criterion["pqmf"].analysis(y)
|
| 246 |
+
y_mb_ = self.criterion["pqmf"].analysis(y_)
|
| 247 |
+
sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
|
| 248 |
+
aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
|
| 249 |
+
self.total_train_loss["train/sub_spectral_convergence_loss"] += sub_sc_loss.item()
|
| 250 |
+
self.total_train_loss["train/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
|
| 251 |
+
|
| 252 |
+
# mel spectrogram loss
|
| 253 |
+
if self.config["use_mel_loss"]:
|
| 254 |
+
mel_loss = self.criterion["mel"](y_, y)
|
| 255 |
+
aux_loss += mel_loss
|
| 256 |
+
self.total_train_loss["train/mel_loss"] += mel_loss.item()
|
| 257 |
+
|
| 258 |
+
# weighting aux loss
|
| 259 |
+
gen_loss += self.config.get("lambda_aux", 1.0) * aux_loss
|
| 260 |
+
|
| 261 |
+
# adversarial loss
|
| 262 |
+
if self.steps > self.config["discriminator_train_start_steps"]:
|
| 263 |
+
p_ = self.model["discriminator"](y_)
|
| 264 |
+
adv_loss = self.criterion["gen_adv"](p_)
|
| 265 |
+
self.total_train_loss["train/adversarial_loss"] += adv_loss.item()
|
| 266 |
+
|
| 267 |
+
# feature matching loss
|
| 268 |
+
if self.config["use_feat_match_loss"]:
|
| 269 |
+
# no need to track gradients
|
| 270 |
+
with torch.no_grad():
|
| 271 |
+
p = self.model["discriminator"](y)
|
| 272 |
+
fm_loss = self.criterion["feat_match"](p_, p)
|
| 273 |
+
self.total_train_loss["train/feature_matching_loss"] += fm_loss.item()
|
| 274 |
+
adv_loss += self.config["lambda_feat_match"] * fm_loss
|
| 275 |
+
|
| 276 |
+
# add adversarial loss to generator loss
|
| 277 |
+
gen_loss += self.config["lambda_adv"] * adv_loss
|
| 278 |
+
|
| 279 |
+
self.total_train_loss["train/generator_loss"] += gen_loss.item()
|
| 280 |
+
|
| 281 |
+
# update generator
|
| 282 |
+
self.optimizer["generator"].zero_grad()
|
| 283 |
+
gen_loss.backward()
|
| 284 |
+
if self.config["generator_grad_norm"] > 0:
|
| 285 |
+
torch.nn.utils.clip_grad_norm_(
|
| 286 |
+
self.model["generator"].parameters(),
|
| 287 |
+
self.config["generator_grad_norm"],
|
| 288 |
+
)
|
| 289 |
+
self.optimizer["generator"].step()
|
| 290 |
+
self.scheduler["generator"].step()
|
| 291 |
+
|
| 292 |
+
#######################
|
| 293 |
+
# Discriminator #
|
| 294 |
+
#######################
|
| 295 |
+
if self.steps > self.config["discriminator_train_start_steps"]:
|
| 296 |
+
# re-compute y_ which leads better quality
|
| 297 |
+
with torch.no_grad():
|
| 298 |
+
# logging.info(f"{vqvec.shape, prompt.shape, mask.shape, prompt_mask.shape}")
|
| 299 |
+
_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask, crop_xlen, x_offsets) # (B, L, 80), (B, C, T)
|
| 300 |
+
|
| 301 |
+
if self.config["generator_params"]["out_channels"] > 1:
|
| 302 |
+
y_ = self.criterion["pqmf"].synthesis(y_)
|
| 303 |
+
|
| 304 |
+
# discriminator loss
|
| 305 |
+
p = self.model["discriminator"](y)
|
| 306 |
+
p_ = self.model["discriminator"](y_.detach())
|
| 307 |
+
real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
|
| 308 |
+
dis_loss = real_loss + fake_loss
|
| 309 |
+
self.total_train_loss["train/real_loss"] += real_loss.item()
|
| 310 |
+
self.total_train_loss["train/fake_loss"] += fake_loss.item()
|
| 311 |
+
self.total_train_loss["train/discriminator_loss"] += dis_loss.item()
|
| 312 |
+
|
| 313 |
+
# update discriminator
|
| 314 |
+
self.optimizer["discriminator"].zero_grad()
|
| 315 |
+
dis_loss.backward()
|
| 316 |
+
if self.config["discriminator_grad_norm"] > 0:
|
| 317 |
+
torch.nn.utils.clip_grad_norm_(
|
| 318 |
+
self.model["discriminator"].parameters(),
|
| 319 |
+
self.config["discriminator_grad_norm"],
|
| 320 |
+
)
|
| 321 |
+
self.optimizer["discriminator"].step()
|
| 322 |
+
self.scheduler["discriminator"].step()
|
| 323 |
+
|
| 324 |
+
# update counts
|
| 325 |
+
self.steps += 1
|
| 326 |
+
self.tqdm.update(1)
|
| 327 |
+
self._check_train_finish()
|
| 328 |
+
|
| 329 |
+
def _train_epoch(self):
|
| 330 |
+
"""Train model one epoch."""
|
| 331 |
+
for train_steps_per_epoch, batch in enumerate(self.data_loader["train"], 1):
|
| 332 |
+
# train one step
|
| 333 |
+
self._train_step(batch)
|
| 334 |
+
|
| 335 |
+
# check interval
|
| 336 |
+
if self.config["rank"] == 0:
|
| 337 |
+
self._check_log_interval()
|
| 338 |
+
self._check_eval_interval()
|
| 339 |
+
self._check_save_interval()
|
| 340 |
+
|
| 341 |
+
# check whether training is finished
|
| 342 |
+
if self.finish_train:
|
| 343 |
+
return
|
| 344 |
+
|
| 345 |
+
# update
|
| 346 |
+
self.epochs += 1
|
| 347 |
+
self.train_steps_per_epoch = train_steps_per_epoch
|
| 348 |
+
logging.info(
|
| 349 |
+
f"(Steps: {self.steps}) Finished {self.epochs} epoch training "
|
| 350 |
+
f"({self.train_steps_per_epoch} steps per epoch)."
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
# needed for shuffle in distributed training
|
| 354 |
+
if self.config["distributed"]:
|
| 355 |
+
self.sampler["train"].set_epoch(self.epochs)
|
| 356 |
+
|
| 357 |
+
@torch.no_grad()
|
| 358 |
+
def _eval_step(self, batch):
|
| 359 |
+
"""Evaluate model one step."""
|
| 360 |
+
# parse batch
|
| 361 |
+
vqidx, mel, prompt, y, xlens, prompt_lens = batch
|
| 362 |
+
vqidx = vqidx.to(self.device).long()
|
| 363 |
+
mel = mel.to(self.device)
|
| 364 |
+
prompt = prompt.to(self.device)
|
| 365 |
+
vqvec = idx2vec(self.feat_codebook, vqidx, self.feat_codebook_numgroups)
|
| 366 |
+
y = y.unsqueeze(-2).to(self.device) # (B, 1, T)
|
| 367 |
+
|
| 368 |
+
# build mask
|
| 369 |
+
mask = make_non_pad_mask(xlens).to(self.device) # (B, L)
|
| 370 |
+
prompt_mask = make_non_pad_mask(prompt_lens).to(self.device) # (B, L_prompt)
|
| 371 |
+
|
| 372 |
+
#######################
|
| 373 |
+
# Generator #
|
| 374 |
+
#######################
|
| 375 |
+
mel_, _, y_ = self.model["generator"](vqvec, prompt, mask, prompt_mask) # (B, L, 80), (B, C, T)
|
| 376 |
+
|
| 377 |
+
# reconstruct the signal from multi-band signal
|
| 378 |
+
if self.config["generator_params"]["out_channels"] > 1:
|
| 379 |
+
y_mb_ = y_
|
| 380 |
+
y_ = self.criterion["pqmf"].synthesis(y_mb_)
|
| 381 |
+
|
| 382 |
+
# initialize
|
| 383 |
+
gen_loss = 0.0
|
| 384 |
+
aux_loss = 0.0
|
| 385 |
+
|
| 386 |
+
# frontend mel prediction loss
|
| 387 |
+
frontend_mel_pred_loss = F.l1_loss(torch.masked_select(mel, mask.unsqueeze(-1)),
|
| 388 |
+
torch.masked_select(mel_, mask.unsqueeze(-1)))
|
| 389 |
+
self.total_eval_loss["eval/frontend_mel_pred_loss"] += frontend_mel_pred_loss.item()
|
| 390 |
+
gen_loss += self.config["lambda_frontend_mel_prediction"] * frontend_mel_pred_loss
|
| 391 |
+
|
| 392 |
+
# multi-resolution stft loss
|
| 393 |
+
if self.config["use_stft_loss"]:
|
| 394 |
+
sc_loss, mag_loss = self.criterion["stft"](y_, y)
|
| 395 |
+
aux_loss += sc_loss + mag_loss
|
| 396 |
+
self.total_eval_loss["eval/spectral_convergence_loss"] += sc_loss.item()
|
| 397 |
+
self.total_eval_loss["eval/log_stft_magnitude_loss"] += mag_loss.item()
|
| 398 |
+
|
| 399 |
+
# subband multi-resolution stft loss
|
| 400 |
+
if self.config.get("use_subband_stft_loss", False):
|
| 401 |
+
aux_loss *= 0.5 # for balancing with subband stft loss
|
| 402 |
+
y_mb = self.criterion["pqmf"].analysis(y)
|
| 403 |
+
sub_sc_loss, sub_mag_loss = self.criterion["sub_stft"](y_mb_, y_mb)
|
| 404 |
+
self.total_eval_loss["eval/sub_spectral_convergence_loss"] += sub_sc_loss.item()
|
| 405 |
+
self.total_eval_loss["eval/sub_log_stft_magnitude_loss"] += sub_mag_loss.item()
|
| 406 |
+
aux_loss += 0.5 * (sub_sc_loss + sub_mag_loss)
|
| 407 |
+
|
| 408 |
+
# mel spectrogram loss
|
| 409 |
+
if self.config["use_mel_loss"]:
|
| 410 |
+
mel_loss = self.criterion["mel"](y_, y)
|
| 411 |
+
aux_loss += mel_loss
|
| 412 |
+
self.total_eval_loss["eval/mel_loss"] += mel_loss.item()
|
| 413 |
+
|
| 414 |
+
# weighting stft loss
|
| 415 |
+
gen_loss += aux_loss * self.config.get("lambda_aux", 1.0)
|
| 416 |
+
|
| 417 |
+
# adversarial loss
|
| 418 |
+
p_ = self.model["discriminator"](y_)
|
| 419 |
+
adv_loss = self.criterion["gen_adv"](p_)
|
| 420 |
+
gen_loss += self.config["lambda_adv"] * adv_loss
|
| 421 |
+
|
| 422 |
+
# feature matching loss
|
| 423 |
+
if self.config["use_feat_match_loss"]:
|
| 424 |
+
p = self.model["discriminator"](y)
|
| 425 |
+
fm_loss = self.criterion["feat_match"](p_, p)
|
| 426 |
+
self.total_eval_loss["eval/feature_matching_loss"] += fm_loss.item()
|
| 427 |
+
gen_loss += (
|
| 428 |
+
self.config["lambda_adv"] * self.config["lambda_feat_match"] * fm_loss
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
#######################
|
| 432 |
+
# Discriminator #
|
| 433 |
+
#######################
|
| 434 |
+
p = self.model["discriminator"](y)
|
| 435 |
+
p_ = self.model["discriminator"](y_)
|
| 436 |
+
|
| 437 |
+
# discriminator loss
|
| 438 |
+
real_loss, fake_loss = self.criterion["dis_adv"](p_, p)
|
| 439 |
+
dis_loss = real_loss + fake_loss
|
| 440 |
+
|
| 441 |
+
# add to total eval loss
|
| 442 |
+
self.total_eval_loss["eval/adversarial_loss"] += adv_loss.item()
|
| 443 |
+
self.total_eval_loss["eval/generator_loss"] += gen_loss.item()
|
| 444 |
+
self.total_eval_loss["eval/real_loss"] += real_loss.item()
|
| 445 |
+
self.total_eval_loss["eval/fake_loss"] += fake_loss.item()
|
| 446 |
+
self.total_eval_loss["eval/discriminator_loss"] += dis_loss.item()
|
| 447 |
+
|
| 448 |
+
def _eval_epoch(self):
|
| 449 |
+
"""Evaluate model one epoch."""
|
| 450 |
+
logging.info(f"(Steps: {self.steps}) Start evaluation.")
|
| 451 |
+
# change mode
|
| 452 |
+
for key in self.model.keys():
|
| 453 |
+
self.model[key].eval()
|
| 454 |
+
|
| 455 |
+
# calculate loss for each batch
|
| 456 |
+
for eval_steps_per_epoch, batch in enumerate(tqdm(self.data_loader["dev"], desc="[eval]"), 1):
|
| 457 |
+
# eval one step
|
| 458 |
+
self._eval_step(batch)
|
| 459 |
+
|
| 460 |
+
logging.info(
|
| 461 |
+
f"(Steps: {self.steps}) Finished evaluation "
|
| 462 |
+
f"({eval_steps_per_epoch} steps per epoch)."
|
| 463 |
+
)
|
| 464 |
+
|
| 465 |
+
# average loss
|
| 466 |
+
for key in self.total_eval_loss.keys():
|
| 467 |
+
self.total_eval_loss[key] /= eval_steps_per_epoch
|
| 468 |
+
logging.info(f"(Steps: {self.steps}) {key} = {self.total_eval_loss[key]:.4f}.")
|
| 469 |
+
|
| 470 |
+
# record
|
| 471 |
+
self._write_to_tensorboard(self.total_eval_loss)
|
| 472 |
+
|
| 473 |
+
# reset
|
| 474 |
+
self.total_eval_loss = defaultdict(float)
|
| 475 |
+
|
| 476 |
+
# restore mode
|
| 477 |
+
for key in self.model.keys():
|
| 478 |
+
self.model[key].train()
|
| 479 |
+
|
| 480 |
+
def _write_to_tensorboard(self, loss):
|
| 481 |
+
"""Write to tensorboard."""
|
| 482 |
+
for key, value in loss.items():
|
| 483 |
+
self.writer.add_scalar(key, value, self.steps)
|
| 484 |
+
|
| 485 |
+
def _check_save_interval(self):
|
| 486 |
+
if self.steps % self.config["save_interval_steps"] == 0:
|
| 487 |
+
self.save_checkpoint(os.path.join(self.config["outdir"],
|
| 488 |
+
f"checkpoint-{self.steps}steps.pkl"))
|
| 489 |
+
logging.info(f"Successfully saved checkpoint @ {self.steps} steps.")
|
| 490 |
+
|
| 491 |
+
def _check_eval_interval(self):
|
| 492 |
+
if self.steps % self.config["eval_interval_steps"] == 0:
|
| 493 |
+
self._eval_epoch()
|
| 494 |
+
|
| 495 |
+
def _check_log_interval(self):
|
| 496 |
+
if self.steps % self.config["log_interval_steps"] == 0:
|
| 497 |
+
for key in self.total_train_loss.keys():
|
| 498 |
+
self.total_train_loss[key] /= self.config["log_interval_steps"]
|
| 499 |
+
logging.info(f"(Steps: {self.steps}) {key} = {self.total_train_loss[key]:.4f}.")
|
| 500 |
+
self._write_to_tensorboard(self.total_train_loss)
|
| 501 |
+
|
| 502 |
+
# reset
|
| 503 |
+
self.total_train_loss = defaultdict(float)
|
| 504 |
+
|
| 505 |
+
def _check_train_finish(self):
|
| 506 |
+
if self.steps >= self.config["train_max_steps"]:
|
| 507 |
+
self.finish_train = True
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class Collator(object):
|
| 511 |
+
"""Customized collator for Pytorch DataLoader in training."""
|
| 512 |
+
|
| 513 |
+
def __init__(
|
| 514 |
+
self,
|
| 515 |
+
hop_size=256,
|
| 516 |
+
win_length=1024,
|
| 517 |
+
sampling_rate=16000,
|
| 518 |
+
prompt_dim=1024,
|
| 519 |
+
prompt_fold_by_2=False
|
| 520 |
+
):
|
| 521 |
+
"""Initialize customized collator for PyTorch DataLoader.
|
| 522 |
+
|
| 523 |
+
Args:
|
| 524 |
+
hop_size (int): Hop size of features, in sampling points.
|
| 525 |
+
win_length (int): window length of features.
|
| 526 |
+
sampling_rate (int): sampling rate of waveform data
|
| 527 |
+
prompt_dim (int): number of prompt embedding dimensions
|
| 528 |
+
"""
|
| 529 |
+
self.hop_size = hop_size
|
| 530 |
+
self.win_length = win_length
|
| 531 |
+
self.sampling_rate = sampling_rate
|
| 532 |
+
self.prompt_dim = prompt_dim
|
| 533 |
+
if prompt_fold_by_2:
|
| 534 |
+
self.prompt_len_factor = 2
|
| 535 |
+
else:
|
| 536 |
+
self.prompt_len_factor = 1
|
| 537 |
+
|
| 538 |
+
def construct_prompt(self, mel_lens):
|
| 539 |
+
prompt_lens = [random.randint(int(l / (3 * self.prompt_len_factor)), int(l / (2 * self.prompt_len_factor))) for l in mel_lens]
|
| 540 |
+
prompt_starts = []
|
| 541 |
+
is_from_start = []
|
| 542 |
+
for ml, pl in zip(mel_lens, prompt_lens):
|
| 543 |
+
if random.random() > 0.5:
|
| 544 |
+
# from start
|
| 545 |
+
prompt_start = random.randint(0, 1 * self.sampling_rate // (self.hop_size * self.prompt_len_factor))
|
| 546 |
+
is_from_start.append(True)
|
| 547 |
+
else:
|
| 548 |
+
# from ending
|
| 549 |
+
prompt_start = random.randint((ml - 1 * self.sampling_rate // self.hop_size) // self.prompt_len_factor, ml // self.prompt_len_factor) - pl
|
| 550 |
+
is_from_start.append(False)
|
| 551 |
+
prompt_starts.append(prompt_start)
|
| 552 |
+
return prompt_lens, prompt_starts, is_from_start
|
| 553 |
+
|
| 554 |
+
def __call__(self, batch):
|
| 555 |
+
"""Convert into batch tensors.
|
| 556 |
+
|
| 557 |
+
Args:
|
| 558 |
+
batch (list): list of tuple of the pair of audio and features.
|
| 559 |
+
|
| 560 |
+
This collator will automatically determine the prompt segment (acoustic context) for each utterance.
|
| 561 |
+
The prompt is cut off from the current utterance, ranging from one third to half of the original utterance.
|
| 562 |
+
The prompt can be cut from either the starting or the ending of the utterance, within 1 second margin.
|
| 563 |
+
The other features include 2-dim VQ features (2 is the number of groups), and D-dim prompts (e.g. WavLM features)
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
Tensor ys: waveform batch (B, T).
|
| 567 |
+
Tensors vqs, mels: Auxiliary feature batch (B, C, T'), where T' = T / hop_size.
|
| 568 |
+
Tensor prompts: prompt feature batch (B, C, T'')
|
| 569 |
+
List c_lengths, prompt_lengths: list of lengths
|
| 570 |
+
"""
|
| 571 |
+
batch = batch[0]
|
| 572 |
+
|
| 573 |
+
# check length
|
| 574 |
+
batch = [self._adjust_length(*b) for b in batch]
|
| 575 |
+
ys, vqs, mels, prompts_old = list(map(list, zip(*batch))) # [(a,b), (c,d)] -> [a, c], [b, d]
|
| 576 |
+
|
| 577 |
+
batch_size = len(vqs)
|
| 578 |
+
|
| 579 |
+
prompt_lengths, prompt_starts, is_from_starts = self.construct_prompt([len(m) for m in mels])
|
| 580 |
+
c_lengths = []
|
| 581 |
+
prompts = torch.zeros(batch_size, max(prompt_lengths), self.prompt_dim)
|
| 582 |
+
for i in range(batch_size):
|
| 583 |
+
prompts[i, :prompt_lengths[i]] = torch.tensor(prompts_old[i][prompt_starts[i]:prompt_starts[i]+prompt_lengths[i], :])
|
| 584 |
+
if is_from_starts[i]:
|
| 585 |
+
start_idx = (prompt_starts[i] + prompt_lengths[i])*self.prompt_len_factor
|
| 586 |
+
mels[i] = mels[i][start_idx:]
|
| 587 |
+
vqs[i] = vqs[i][start_idx:]
|
| 588 |
+
ys[i] = ys[i][start_idx * self.hop_size: ]
|
| 589 |
+
else:
|
| 590 |
+
end_idx = prompt_starts[i]*self.prompt_len_factor
|
| 591 |
+
mels[i] = mels[i][:end_idx]
|
| 592 |
+
vqs[i] = vqs[i][:end_idx]
|
| 593 |
+
ys[i] = ys[i][:end_idx * self.hop_size]
|
| 594 |
+
c_lengths.append(len(mels[i]))
|
| 595 |
+
|
| 596 |
+
vqs = pad_list([torch.tensor(c) for c in vqs], pad_value=0) # (B, L, Groups)
|
| 597 |
+
vqs = vqs.long()
|
| 598 |
+
mels = pad_list([torch.tensor(c) for c in mels], pad_value=0) # (B, L, 80)
|
| 599 |
+
|
| 600 |
+
ys = pad_list([torch.tensor(y, dtype=torch.float) for y in ys], pad_value=0)[:, :mels.size(1) * self.hop_size] # (B, T)
|
| 601 |
+
assert ys.size(1) == mels.size(1) * self.hop_size == vqs.size(1) * self.hop_size
|
| 602 |
+
|
| 603 |
+
return vqs, mels, prompts, ys, c_lengths, prompt_lengths
|
| 604 |
+
|
| 605 |
+
def _adjust_length(self, x, c, *args):
|
| 606 |
+
"""Adjust the audio and feature lengths.
|
| 607 |
+
|
| 608 |
+
Note:
|
| 609 |
+
Basically we assume that the length of x and c are adjusted
|
| 610 |
+
through preprocessing stage, but if we use other library processed
|
| 611 |
+
features, this process will be needed.
|
| 612 |
+
|
| 613 |
+
"""
|
| 614 |
+
if len(x) > len(c) * self.hop_size:
|
| 615 |
+
x = x[(self.win_length - self.hop_size) // 2:]
|
| 616 |
+
x = x[:len(c) * self.hop_size]
|
| 617 |
+
|
| 618 |
+
# check the legnth is valid
|
| 619 |
+
assert len(x) == len(c) * self.hop_size
|
| 620 |
+
|
| 621 |
+
return x, c, *args
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def main(rank, n_gpus):
|
| 625 |
+
"""Run training process."""
|
| 626 |
+
parser = argparse.ArgumentParser(
|
| 627 |
+
description="Train vec2wav2 (See detail in vec2wav2/bin/train.py)."
|
| 628 |
+
)
|
| 629 |
+
parser.add_argument(
|
| 630 |
+
"--train-wav-scp",
|
| 631 |
+
default=None,
|
| 632 |
+
type=str,
|
| 633 |
+
help="kaldi-style wav.scp file for training. "
|
| 634 |
+
)
|
| 635 |
+
parser.add_argument(
|
| 636 |
+
"--train-vqidx-scp",
|
| 637 |
+
default=None,
|
| 638 |
+
type=str,
|
| 639 |
+
help="kaldi-style feats.scp file for training. "
|
| 640 |
+
)
|
| 641 |
+
parser.add_argument(
|
| 642 |
+
"--train-mel-scp",
|
| 643 |
+
default=None,
|
| 644 |
+
type=str,
|
| 645 |
+
help="kaldi-style feats.scp file for training. "
|
| 646 |
+
)
|
| 647 |
+
parser.add_argument(
|
| 648 |
+
"--train-prompt-scp",
|
| 649 |
+
default=None,
|
| 650 |
+
type=str,
|
| 651 |
+
help="prompt scp (in this case, utt to path)"
|
| 652 |
+
)
|
| 653 |
+
parser.add_argument(
|
| 654 |
+
"--train-segments",
|
| 655 |
+
default=None,
|
| 656 |
+
type=str,
|
| 657 |
+
help="kaldi-style segments file for training.",
|
| 658 |
+
)
|
| 659 |
+
parser.add_argument(
|
| 660 |
+
"--train-num-frames",
|
| 661 |
+
default=None,
|
| 662 |
+
type=str,
|
| 663 |
+
help="kaldi-style utt2num_frames file for training.",
|
| 664 |
+
)
|
| 665 |
+
parser.add_argument(
|
| 666 |
+
"--dev-wav-scp",
|
| 667 |
+
default=None,
|
| 668 |
+
type=str,
|
| 669 |
+
help="kaldi-style wav.scp file for validation. "
|
| 670 |
+
)
|
| 671 |
+
parser.add_argument(
|
| 672 |
+
"--dev-vqidx-scp",
|
| 673 |
+
default=None,
|
| 674 |
+
type=str,
|
| 675 |
+
help="kaldi-style feats.scp file for vaidation. "
|
| 676 |
+
)
|
| 677 |
+
parser.add_argument(
|
| 678 |
+
"--dev-mel-scp",
|
| 679 |
+
default=None,
|
| 680 |
+
type=str,
|
| 681 |
+
help="kaldi-style feats.scp file for vaidation. "
|
| 682 |
+
)
|
| 683 |
+
parser.add_argument(
|
| 684 |
+
"--dev-prompt-scp",
|
| 685 |
+
default=None,
|
| 686 |
+
type=str,
|
| 687 |
+
help="prompt scp (in this case, utt to path)"
|
| 688 |
+
)
|
| 689 |
+
parser.add_argument(
|
| 690 |
+
"--dev-segments",
|
| 691 |
+
default=None,
|
| 692 |
+
type=str,
|
| 693 |
+
help="kaldi-style segments file for validation.",
|
| 694 |
+
)
|
| 695 |
+
parser.add_argument(
|
| 696 |
+
"--dev-num-frames",
|
| 697 |
+
default=None,
|
| 698 |
+
type=str,
|
| 699 |
+
help="kaldi-style utt2num_frames file for validation.",
|
| 700 |
+
)
|
| 701 |
+
parser.add_argument(
|
| 702 |
+
"--outdir",
|
| 703 |
+
type=str,
|
| 704 |
+
required=True,
|
| 705 |
+
help="directory to save checkpoints.",
|
| 706 |
+
)
|
| 707 |
+
parser.add_argument(
|
| 708 |
+
"--config",
|
| 709 |
+
type=str,
|
| 710 |
+
required=True,
|
| 711 |
+
help="yaml format configuration file.",
|
| 712 |
+
)
|
| 713 |
+
parser.add_argument(
|
| 714 |
+
"--pretrain",
|
| 715 |
+
default="",
|
| 716 |
+
type=str,
|
| 717 |
+
nargs="?",
|
| 718 |
+
help='checkpoint file path to load pretrained params. (default="")',
|
| 719 |
+
)
|
| 720 |
+
parser.add_argument(
|
| 721 |
+
"--resume",
|
| 722 |
+
default="",
|
| 723 |
+
type=str,
|
| 724 |
+
nargs="?",
|
| 725 |
+
help='checkpoint file path to resume training. (default="")',
|
| 726 |
+
)
|
| 727 |
+
parser.add_argument(
|
| 728 |
+
"--verbose",
|
| 729 |
+
type=int,
|
| 730 |
+
default=1,
|
| 731 |
+
help="logging level. higher is more logging. (default=1)",
|
| 732 |
+
)
|
| 733 |
+
parser.add_argument("--vq-codebook", default=None, type=str)
|
| 734 |
+
# parser.add_argument("--sampling-rate", type=int)
|
| 735 |
+
# parser.add_argument("--num-mels", type=int)
|
| 736 |
+
# parser.add_argument("--hop-size", type=int)
|
| 737 |
+
# parser.add_argument("--win-length", type=int)
|
| 738 |
+
args = parser.parse_args()
|
| 739 |
+
|
| 740 |
+
# init distributed training
|
| 741 |
+
device = torch.device("cuda")
|
| 742 |
+
# effective when using fixed size inputs
|
| 743 |
+
# see https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
|
| 744 |
+
torch.backends.cudnn.benchmark = True
|
| 745 |
+
# setup for distributed training
|
| 746 |
+
# see example: https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed
|
| 747 |
+
if n_gpus == 1:
|
| 748 |
+
assert rank == 0
|
| 749 |
+
|
| 750 |
+
set_loglevel(args.verbose)
|
| 751 |
+
|
| 752 |
+
# check directory existence
|
| 753 |
+
if not os.path.exists(args.outdir):
|
| 754 |
+
os.makedirs(args.outdir)
|
| 755 |
+
|
| 756 |
+
# init process group
|
| 757 |
+
logging.info("Synchronizing between all workers.")
|
| 758 |
+
torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=n_gpus, rank=rank)
|
| 759 |
+
torch.cuda.set_device(rank)
|
| 760 |
+
logging.info("Finished init process group.")
|
| 761 |
+
|
| 762 |
+
# load and save config
|
| 763 |
+
with open(args.config) as f:
|
| 764 |
+
config = yaml.load(f, Loader=yaml.Loader)
|
| 765 |
+
config.update(vars(args))
|
| 766 |
+
config['rank'] = rank
|
| 767 |
+
config['distributed'] = True
|
| 768 |
+
config['world_size'] = n_gpus
|
| 769 |
+
config["version"] = vec2wav2.__version__ # add version info
|
| 770 |
+
if rank == 0:
|
| 771 |
+
with open(os.path.join(args.outdir, "config.yml"), "w") as f:
|
| 772 |
+
yaml.dump(config, f, Dumper=yaml.Dumper)
|
| 773 |
+
for key, value in config.items():
|
| 774 |
+
logging.info(f"{key} = {value}")
|
| 775 |
+
|
| 776 |
+
# get dataset
|
| 777 |
+
train_dataset = AudioMelSCPDataset(
|
| 778 |
+
wav_scp=args.train_wav_scp,
|
| 779 |
+
vqidx_scp=args.train_vqidx_scp,
|
| 780 |
+
mel_scp=args.train_mel_scp,
|
| 781 |
+
prompt_scp=args.train_prompt_scp,
|
| 782 |
+
utt2num_frames=args.train_num_frames,
|
| 783 |
+
segments=args.train_segments,
|
| 784 |
+
batch_frames=config.get("batch_frames", None),
|
| 785 |
+
batch_size=config.get("batch_size", None),
|
| 786 |
+
min_num_frames=config.get("min_num_frames", None),
|
| 787 |
+
max_num_frames=config.get("max_num_frames", None),
|
| 788 |
+
allow_cache=config.get("allow_cache", False), # keep compatibility
|
| 789 |
+
length_tolerance=config.get("length_tolerance", 2),
|
| 790 |
+
prompt_fold_by_2=config.get("prompt_fold_by_2", True)
|
| 791 |
+
)
|
| 792 |
+
if rank == 0:
|
| 793 |
+
logging.info(f"The number of training batches = {len(train_dataset)}.")
|
| 794 |
+
dev_dataset = AudioMelSCPDataset(
|
| 795 |
+
wav_scp=args.dev_wav_scp,
|
| 796 |
+
vqidx_scp=args.dev_vqidx_scp,
|
| 797 |
+
mel_scp=args.dev_mel_scp,
|
| 798 |
+
prompt_scp=args.dev_prompt_scp,
|
| 799 |
+
utt2num_frames=args.dev_num_frames,
|
| 800 |
+
segments=args.dev_segments,
|
| 801 |
+
min_num_frames=config.get("min_num_frames", None),
|
| 802 |
+
max_num_frames=config.get("max_num_frames", None),
|
| 803 |
+
allow_cache=config.get("allow_cache", False), # keep compatibility
|
| 804 |
+
length_tolerance=config.get("length_tolerance", 2),
|
| 805 |
+
prompt_fold_by_2=config.get("prompt_fold_by_2", True)
|
| 806 |
+
)
|
| 807 |
+
if rank == 0:
|
| 808 |
+
logging.info(f"The number of development batches = {len(dev_dataset)}.")
|
| 809 |
+
dataset = {
|
| 810 |
+
"train": train_dataset,
|
| 811 |
+
"dev": dev_dataset,
|
| 812 |
+
}
|
| 813 |
+
|
| 814 |
+
# get data loader
|
| 815 |
+
collator = Collator(
|
| 816 |
+
hop_size=config["hop_size"],
|
| 817 |
+
win_length=config["win_length"],
|
| 818 |
+
sampling_rate=config["sampling_rate"],
|
| 819 |
+
prompt_dim=config['frontend_params']['prompt_channels'],
|
| 820 |
+
prompt_fold_by_2=config.get("prompt_fold_by_2", True)
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
sampler = {
|
| 824 |
+
"train": DistributedSampler(
|
| 825 |
+
dataset=dataset["train"],
|
| 826 |
+
num_replicas=n_gpus,
|
| 827 |
+
rank=rank,
|
| 828 |
+
shuffle=True,
|
| 829 |
+
),
|
| 830 |
+
"dev": DistributedSampler(
|
| 831 |
+
dataset=dataset["dev"],
|
| 832 |
+
num_replicas=n_gpus,
|
| 833 |
+
rank=rank,
|
| 834 |
+
shuffle=False,
|
| 835 |
+
)}
|
| 836 |
+
data_loader = {
|
| 837 |
+
"train": DataLoader(
|
| 838 |
+
dataset=dataset["train"],
|
| 839 |
+
shuffle=False,
|
| 840 |
+
collate_fn=collator,
|
| 841 |
+
num_workers=config["num_workers"],
|
| 842 |
+
sampler=sampler["train"],
|
| 843 |
+
pin_memory=config["pin_memory"],
|
| 844 |
+
),
|
| 845 |
+
"dev": DataLoader(
|
| 846 |
+
dataset=dataset["dev"],
|
| 847 |
+
shuffle=False,
|
| 848 |
+
collate_fn=collator,
|
| 849 |
+
num_workers=config["num_workers"],
|
| 850 |
+
sampler=sampler["dev"],
|
| 851 |
+
pin_memory=config["pin_memory"],
|
| 852 |
+
),
|
| 853 |
+
}
|
| 854 |
+
|
| 855 |
+
# define models
|
| 856 |
+
generator_class = getattr(
|
| 857 |
+
vec2wav2.models,
|
| 858 |
+
# keep compatibility
|
| 859 |
+
config.get("generator_type", "ParallelWaveGANGenerator"),
|
| 860 |
+
)
|
| 861 |
+
discriminator_class = getattr(
|
| 862 |
+
vec2wav2.models,
|
| 863 |
+
# keep compatibility
|
| 864 |
+
config.get("discriminator_type", "ParallelWaveGANDiscriminator"),
|
| 865 |
+
)
|
| 866 |
+
model = {
|
| 867 |
+
"generator": vec2wav2.models.VEC2WAV2Generator(
|
| 868 |
+
vec2wav2.models.CTXVEC2WAVFrontend(config["prompt_net_type"], config["num_mels"], **config["frontend_params"]),
|
| 869 |
+
generator_class(**config["generator_params"])
|
| 870 |
+
).to(device),
|
| 871 |
+
"discriminator": discriminator_class(
|
| 872 |
+
**config["discriminator_params"],
|
| 873 |
+
).to(device),
|
| 874 |
+
}
|
| 875 |
+
|
| 876 |
+
# define criteria
|
| 877 |
+
criterion = {
|
| 878 |
+
"gen_adv": GeneratorAdversarialLoss(
|
| 879 |
+
# keep compatibility
|
| 880 |
+
**config.get("generator_adv_loss_params", {})
|
| 881 |
+
).to(device),
|
| 882 |
+
"dis_adv": DiscriminatorAdversarialLoss(
|
| 883 |
+
# keep compatibility
|
| 884 |
+
**config.get("discriminator_adv_loss_params", {})
|
| 885 |
+
).to(device),
|
| 886 |
+
}
|
| 887 |
+
if config.get("use_stft_loss", True): # keep compatibility
|
| 888 |
+
config["use_stft_loss"] = True
|
| 889 |
+
criterion["stft"] = MultiResolutionSTFTLoss(**config["stft_loss_params"]).to(device)
|
| 890 |
+
if config.get("use_subband_stft_loss", False): # keep compatibility
|
| 891 |
+
assert config["generator_params"]["out_channels"] > 1
|
| 892 |
+
criterion["sub_stft"] = MultiResolutionSTFTLoss(**config["subband_stft_loss_params"]).to(device)
|
| 893 |
+
else:
|
| 894 |
+
config["use_subband_stft_loss"] = False
|
| 895 |
+
if config.get("use_feat_match_loss", False): # keep compatibility
|
| 896 |
+
criterion["feat_match"] = FeatureMatchLoss(
|
| 897 |
+
# keep compatibility
|
| 898 |
+
**config.get("feat_match_loss_params", {}),
|
| 899 |
+
).to(device)
|
| 900 |
+
else:
|
| 901 |
+
config["use_feat_match_loss"] = False
|
| 902 |
+
if config.get("use_mel_loss", False): # keep compatibility
|
| 903 |
+
criterion["mel"] = MelSpectrogramLoss(**config["mel_loss_params"],).to(device)
|
| 904 |
+
else:
|
| 905 |
+
config["use_mel_loss"] = False
|
| 906 |
+
|
| 907 |
+
# define optimizers and schedulers
|
| 908 |
+
generator_optimizer_class = getattr(
|
| 909 |
+
vec2wav2.optimizers,
|
| 910 |
+
# keep compatibility
|
| 911 |
+
config.get("generator_optimizer_type", "RAdam"),
|
| 912 |
+
)
|
| 913 |
+
discriminator_optimizer_class = getattr(
|
| 914 |
+
vec2wav2.optimizers,
|
| 915 |
+
# keep compatibility
|
| 916 |
+
config.get("discriminator_optimizer_type", "RAdam"),
|
| 917 |
+
)
|
| 918 |
+
optimizer = {
|
| 919 |
+
"generator": generator_optimizer_class(
|
| 920 |
+
model["generator"].parameters(),
|
| 921 |
+
**config["generator_optimizer_params"],
|
| 922 |
+
),
|
| 923 |
+
"discriminator": discriminator_optimizer_class(
|
| 924 |
+
model["discriminator"].parameters(),
|
| 925 |
+
**config["discriminator_optimizer_params"],
|
| 926 |
+
),
|
| 927 |
+
}
|
| 928 |
+
generator_scheduler_class = getattr(
|
| 929 |
+
torch.optim.lr_scheduler,
|
| 930 |
+
# keep compatibility
|
| 931 |
+
config.get("generator_scheduler_type", "StepLR"),
|
| 932 |
+
)
|
| 933 |
+
discriminator_scheduler_class = getattr(
|
| 934 |
+
torch.optim.lr_scheduler,
|
| 935 |
+
# keep compatibility
|
| 936 |
+
config.get("discriminator_scheduler_type", "StepLR"),
|
| 937 |
+
)
|
| 938 |
+
scheduler = {
|
| 939 |
+
"generator": generator_scheduler_class(
|
| 940 |
+
optimizer=optimizer["generator"],
|
| 941 |
+
**config["generator_scheduler_params"],
|
| 942 |
+
),
|
| 943 |
+
"discriminator": discriminator_scheduler_class(
|
| 944 |
+
optimizer=optimizer["discriminator"],
|
| 945 |
+
**config["discriminator_scheduler_params"],
|
| 946 |
+
),
|
| 947 |
+
}
|
| 948 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 949 |
+
model["generator"] = DistributedDataParallel(model["generator"], device_ids=[rank], find_unused_parameters=True)
|
| 950 |
+
model["discriminator"] = DistributedDataParallel(model["discriminator"], device_ids=[rank], find_unused_parameters=True)
|
| 951 |
+
|
| 952 |
+
if rank == 0:
|
| 953 |
+
# show settings
|
| 954 |
+
logging.info(model["generator"])
|
| 955 |
+
logging.info(f"Generator has nparams: {sum([p.numel() for p in model['generator'].parameters()])}")
|
| 956 |
+
logging.info(model["discriminator"])
|
| 957 |
+
logging.info(f"Discriminator has nparams: {sum([p.numel() for p in model['discriminator'].parameters()])}")
|
| 958 |
+
logging.info(optimizer["generator"])
|
| 959 |
+
logging.info(optimizer["discriminator"])
|
| 960 |
+
|
| 961 |
+
# define trainer
|
| 962 |
+
trainer = Trainer(
|
| 963 |
+
steps=0,
|
| 964 |
+
epochs=0,
|
| 965 |
+
data_loader=data_loader,
|
| 966 |
+
sampler=sampler,
|
| 967 |
+
model=model,
|
| 968 |
+
criterion=criterion,
|
| 969 |
+
optimizer=optimizer,
|
| 970 |
+
scheduler=scheduler,
|
| 971 |
+
config=config,
|
| 972 |
+
device=device,
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
# load pretrained parameters from checkpoint
|
| 976 |
+
if len(args.pretrain) != 0:
|
| 977 |
+
trainer.load_checkpoint(args.pretrain, load_only_params=True)
|
| 978 |
+
if rank == 0:
|
| 979 |
+
logging.info(f"Successfully load parameters from {args.pretrain}.")
|
| 980 |
+
|
| 981 |
+
# resume from checkpoint
|
| 982 |
+
if len(args.resume) != 0:
|
| 983 |
+
trainer.load_checkpoint(args.resume)
|
| 984 |
+
if rank == 0:
|
| 985 |
+
logging.info(f"Successfully resumed from {args.resume}.")
|
| 986 |
+
|
| 987 |
+
# run training loop
|
| 988 |
+
try:
|
| 989 |
+
trainer.run()
|
| 990 |
+
finally:
|
| 991 |
+
if rank == 0:
|
| 992 |
+
trainer.save_checkpoint(os.path.join(config["outdir"], f"checkpoint-{trainer.steps}steps.pkl"))
|
| 993 |
+
logging.info(f"Successfully saved checkpoint @ {trainer.steps}steps.")
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
if __name__ == "__main__":
|
| 997 |
+
assert torch.cuda.is_available(), "CPU training is not allowed."
|
| 998 |
+
n_gpus = torch.cuda.device_count()
|
| 999 |
+
print(f"============> using {n_gpus} GPUS")
|
| 1000 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 1001 |
+
os.environ["MASTER_PORT"] = "8000"
|
| 1002 |
+
|
| 1003 |
+
mp.spawn(
|
| 1004 |
+
main,
|
| 1005 |
+
nprocs=n_gpus,
|
| 1006 |
+
args=(n_gpus,)
|
| 1007 |
+
)
|
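
Note on the prompt segmentation described in the Collator.__call__ docstring above: its behaviour can be summarised with a small self-contained sketch. This helper is illustrative only and is not part of the committed train.py; the hop size, sampling rate, and margin values are placeholders chosen to match the docstring's "1 second margin" wording.

import random

def sketch_prompt_segment(mel_len, hop_size=256, sampling_rate=16000):
    # The prompt spans roughly one third to one half of the utterance (in mel frames).
    prompt_len = random.randint(mel_len // 3, mel_len // 2)
    margin = sampling_rate // hop_size          # about 1 second, expressed in mel frames
    from_start = random.random() < 0.5          # cut the prompt from the start or from the end
    if from_start:
        prompt_start = random.randint(0, min(margin, mel_len - prompt_len))
    else:
        prompt_start = random.randint(max(0, mel_len - prompt_len - margin), mel_len - prompt_len)
    return prompt_start, prompt_len, from_start
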
vec2wav2/bin/vc.py
ADDED
|
@@ -0,0 +1,128 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright 2024 Yiwei Guo
|
| 4 |
+
|
| 5 |
+
""" Run VC inference with trained model """
|
| 6 |
+
|
| 7 |
+
import vec2wav2
|
| 8 |
+
from vec2wav2.ssl_models.vqw2v_extractor import Extractor as VQW2VExtractor
|
| 9 |
+
from vec2wav2.ssl_models.wavlm_extractor import Extractor as WavLMExtractor
|
| 10 |
+
# from vec2wav2.ssl_models.w2v2_extractor import Extractor as W2V2Extractor
|
| 11 |
+
import torch
|
| 12 |
+
import logging
|
| 13 |
+
import argparse
|
| 14 |
+
from vec2wav2.utils.utils import load_model, load_feat_codebook, idx2vec, read_wav_16k
|
| 15 |
+
import soundfile as sf
|
| 16 |
+
import yaml
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def configure_logging(verbose):
|
| 21 |
+
if verbose:
|
| 22 |
+
logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.DEBUG)
|
| 23 |
+
logging.getLogger().setLevel(logging.DEBUG)
|
| 24 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 25 |
+
else:
|
| 26 |
+
logging.getLogger("vec2wav2.ssl_models.WavLM").setLevel(logging.ERROR)
|
| 27 |
+
logging.getLogger().setLevel(logging.ERROR)
|
| 28 |
+
logging.basicConfig(level=logging.ERROR)
|
| 29 |
+
|
| 30 |
+
script_logger = logging.getLogger("script_logger")
|
| 31 |
+
handler = logging.StreamHandler()
|
| 32 |
+
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s | %(levelname)s | %(message)s'))
|
| 33 |
+
script_logger.addHandler(handler)
|
| 34 |
+
script_logger.setLevel(logging.INFO)
|
| 35 |
+
script_logger.propagate = False
|
| 36 |
+
return script_logger
|
| 37 |
+
|
| 38 |
+
def vc_args():
|
| 39 |
+
parser = argparse.ArgumentParser()
|
| 40 |
+
# required arguments
|
| 41 |
+
parser.add_argument("-s", "--source", default="examples/source.wav", type=str,
|
| 42 |
+
help="source wav path")
|
| 43 |
+
parser.add_argument("-t", "--target", default="examples/target.wav", type=str,
|
| 44 |
+
help="target speaker prompt path")
|
| 45 |
+
parser.add_argument("-o", "--output", default="output.wav", type=str,
|
| 46 |
+
help="path of the output wav file")
|
| 47 |
+
|
| 48 |
+
# optional arguments
|
| 49 |
+
parser.add_argument("--expdir", default="pretrained/", type=str,
|
| 50 |
+
help="path to find model checkpoints and configs. Will load expdir/generator.ckpt and expdir/config.yml.")
|
| 51 |
+
parser.add_argument('--checkpoint', default=None, type=str, help="checkpoint path (.pkl). If provided, will override expdir.")
|
| 52 |
+
parser.add_argument("--token-extractor", default="pretrained/vq-wav2vec_kmeans.pt", type=str,
|
| 53 |
+
help="checkpoint or model flag of input token extractor")
|
| 54 |
+
parser.add_argument("--prompt-extractor", default="pretrained/WavLM-Large.pt", type=str,
|
| 55 |
+
help="checkpoint or model flag of speaker prompt extractor")
|
| 56 |
+
parser.add_argument("--prompt-output-layer", default=6, type=int,
|
| 57 |
+
help="output layer when prompt is extracted from WavLM.")
|
| 58 |
+
|
| 59 |
+
parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
|
| 60 |
+
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
return args
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class VoiceConverter:
|
| 66 |
+
def __init__(self, expdir="pretrained/", token_extractor="pretrained/vq-wav2vec_kmeans.pt",
|
| 67 |
+
prompt_extractor="pretrained/WavLM-Large.pt", prompt_output_layer=6,
|
| 68 |
+
checkpoint=None, script_logger=None):
|
| 69 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 70 |
+
self.script_logger = script_logger
|
| 71 |
+
self.log_if_possible(f"Using device: {self.device}")
|
| 72 |
+
|
| 73 |
+
# set up token extractor
|
| 74 |
+
self.token_extractor = VQW2VExtractor(checkpoint=token_extractor, device=self.device)
|
| 75 |
+
feat_codebook, feat_codebook_numgroups = load_feat_codebook(self.token_extractor.get_codebook(), self.device)
|
| 76 |
+
self.feat_codebook = feat_codebook
|
| 77 |
+
self.feat_codebook_numgroups = feat_codebook_numgroups
|
| 78 |
+
self.log_if_possible(f"Successfully set up token extractor from {token_extractor}")
|
| 79 |
+
|
| 80 |
+
# set up prompt extractor
|
| 81 |
+
self.prompt_extractor = WavLMExtractor(prompt_extractor, device=self.device, output_layer=prompt_output_layer)
|
| 82 |
+
self.log_if_possible(f"Successfully set up prompt extractor from {prompt_extractor}")
|
| 83 |
+
|
| 84 |
+
# load VC model
|
| 85 |
+
self.config_path = os.path.join(expdir, "config.yml")
|
| 86 |
+
with open(self.config_path) as f:
|
| 87 |
+
self.config = yaml.load(f, Loader=yaml.Loader)
|
| 88 |
+
if checkpoint is not None:
|
| 89 |
+
checkpoint = os.path.join(expdir, checkpoint)
|
| 90 |
+
else:
|
| 91 |
+
checkpoint = os.path.join(expdir, "generator.ckpt")
|
| 92 |
+
self.model = load_model(checkpoint, self.config)
|
| 93 |
+
self.log_if_possible(f"Successfully set up VC model from {checkpoint}")
|
| 94 |
+
|
| 95 |
+
self.model.backend.remove_weight_norm()
|
| 96 |
+
self.model.eval().to(self.device)
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def voice_conversion(self, source_audio, target_audio, output_path="output.wav"):
|
| 100 |
+
self.log_if_possible(f"Performing VC from {source_audio} to {target_audio}")
|
| 101 |
+
source_wav = read_wav_16k(source_audio)
|
| 102 |
+
target_wav = read_wav_16k(target_audio)
|
| 103 |
+
vq_idx = self.token_extractor.extract(source_wav).long().to(self.device)
|
| 104 |
+
|
| 105 |
+
vqvec = idx2vec(self.feat_codebook, vq_idx, self.feat_codebook_numgroups).unsqueeze(0)
|
| 106 |
+
prompt = self.prompt_extractor.extract(target_wav).unsqueeze(0).to(self.device)
|
| 107 |
+
converted = self.model.inference(vqvec, prompt)[-1].view(-1)
|
| 108 |
+
sf.write(output_path, converted.cpu().numpy(), self.config['sampling_rate'])
|
| 109 |
+
self.log_if_possible(f"Saved audio file to {output_path}")
|
| 110 |
+
return output_path
|
| 111 |
+
|
| 112 |
+
def log_if_possible(self, msg):
|
| 113 |
+
if self.script_logger is not None:
|
| 114 |
+
self.script_logger.info(msg)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
args = vc_args()
|
| 119 |
+
script_logger = configure_logging(args.verbose)
|
| 120 |
+
|
| 121 |
+
source_wav = read_wav_16k(args.source)
|
| 122 |
+
target_prompt = read_wav_16k(args.target)
|
| 123 |
+
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
voice_converter = VoiceConverter(expdir=args.expdir, token_extractor=args.token_extractor,
|
| 126 |
+
prompt_extractor=args.prompt_extractor, prompt_output_layer=args.prompt_output_layer,
|
| 127 |
+
checkpoint=args.checkpoint, script_logger=script_logger)
|
| 128 |
+
voice_converter.voice_conversion(args.source, args.target, args.output)
|
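
For reference, the VoiceConverter class defined above can also be driven directly from Python. A minimal sketch follows; the file paths mirror the argparse defaults in vc.py and are placeholders for real files.

import torch
from vec2wav2.bin.vc import VoiceConverter, configure_logging

logger = configure_logging(verbose=False)
with torch.no_grad():
    vc = VoiceConverter(
        expdir="pretrained/",                               # must contain config.yml and generator.ckpt
        token_extractor="pretrained/vq-wav2vec_kmeans.pt",
        prompt_extractor="pretrained/WavLM-Large.pt",
        prompt_output_layer=6,
        script_logger=logger,
    )
    # convert the source utterance to the target speaker's voice
    vc.voice_conversion("examples/source.wav", "examples/target.wav", "output.wav")
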
vec2wav2/datasets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from .scp_dataset import * # NOQA
|
vec2wav2/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (288 Bytes).
vec2wav2/datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (241 Bytes).
vec2wav2/datasets/__pycache__/scp_dataset.cpython-310.pyc
ADDED
Binary file (8.4 kB).
vec2wav2/datasets/__pycache__/scp_dataset.cpython-39.pyc
ADDED
Binary file (8.95 kB).
vec2wav2/datasets/scp_dataset.py
ADDED
|
@@ -0,0 +1,300 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2019 Tomoki Hayashi
|
| 4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 5 |
+
|
| 6 |
+
# Modified by Yiwei Guo, 2024
|
| 7 |
+
|
| 8 |
+
"""Dataset modules based on kaldi-style scp files."""
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import random
|
| 12 |
+
import copy
|
| 13 |
+
from multiprocessing import Manager
|
| 14 |
+
|
| 15 |
+
import kaldiio
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
from torch.utils.data import Dataset
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from vec2wav2.utils import HDF5ScpLoader
|
| 21 |
+
from vec2wav2.utils import NpyScpLoader
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _get_feats_scp_loader(feats_scp):
|
| 25 |
+
# read the first line of feats.scp file
|
| 26 |
+
with open(feats_scp) as f:
|
| 27 |
+
key, value = f.readlines()[0].replace("\n", "").split()
|
| 28 |
+
|
| 29 |
+
# check scp type
|
| 30 |
+
if ":" in value:
|
| 31 |
+
value_1, value_2 = value.split(":")
|
| 32 |
+
if value_1.endswith(".ark"):
|
| 33 |
+
# kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
|
| 34 |
+
return kaldiio.load_scp(feats_scp)
|
| 35 |
+
elif value_1.endswith(".h5"):
|
| 36 |
+
# hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
|
| 37 |
+
return HDF5ScpLoader(feats_scp)
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError("Not supported feats.scp type.")
|
| 40 |
+
else:
|
| 41 |
+
if value.endswith(".h5"):
|
| 42 |
+
# hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
|
| 43 |
+
return HDF5ScpLoader(feats_scp)
|
| 44 |
+
elif value.endswith(".npy"):
|
| 45 |
+
# npy case: utt_id_1 /path/to/utt_id_1.npy
|
| 46 |
+
return NpyScpLoader(feats_scp)
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError("Not supported feats.scp type.")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class AudioMelSCPDataset(Dataset):
|
| 52 |
+
"""PyTorch compatible audio and feat dataset based on kaldi-stype scp files."""
|
| 53 |
+
|
| 54 |
+
def __init__(
|
| 55 |
+
self,
|
| 56 |
+
wav_scp,
|
| 57 |
+
vqidx_scp,
|
| 58 |
+
mel_scp,
|
| 59 |
+
prompt_scp,
|
| 60 |
+
utt2num_frames=None,
|
| 61 |
+
segments=None,
|
| 62 |
+
batch_frames=None,
|
| 63 |
+
batch_size=None,
|
| 64 |
+
min_num_frames=None,
|
| 65 |
+
max_num_frames=None,
|
| 66 |
+
return_utt_id=False,
|
| 67 |
+
return_sampling_rate=False,
|
| 68 |
+
allow_cache=False,
|
| 69 |
+
length_tolerance=2,
|
| 70 |
+
prompt_fold_by_2=True
|
| 71 |
+
):
|
| 72 |
+
"""Initialize dataset.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
wav_scp (str): Kaldi-style wav.scp file.
|
| 76 |
+
vqidx_scp (str): Kaldi-style feats.scp file.
|
| 77 |
+
mel_scp (str): Kaldi-style feats.scp file.
|
| 78 |
+
segments (str): Kaldi-style segments file.
|
| 79 |
+
min_num_frames (int): Threshold to remove short feature files.
|
| 80 |
+
max_num_frames (int): Threshold to remove long feature files.
|
| 81 |
+
return_utt_id (bool): Whether to return utterance id.
|
| 82 |
+
return_sampling_rate (bool): Whether to return sampling rate.
|
| 83 |
+
allow_cache (bool): Whether to allow cache of the loaded files.
|
| 84 |
+
prompt_fold_by_2 (bool): If true, the prompt has half the length of the vqidx sequence.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
# load scp as lazy dict
|
| 88 |
+
self.audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
|
| 89 |
+
self.vqidx_loader = _get_feats_scp_loader(vqidx_scp)
|
| 90 |
+
self.mel_loader = _get_feats_scp_loader(mel_scp)
|
| 91 |
+
|
| 92 |
+
self.prompt_loader = _get_feats_scp_loader(prompt_scp)
|
| 93 |
+
|
| 94 |
+
self.utt_ids = list(self.mel_loader.keys())
|
| 95 |
+
self.return_utt_id = return_utt_id
|
| 96 |
+
self.return_sampling_rate = return_sampling_rate
|
| 97 |
+
self.allow_cache = allow_cache
|
| 98 |
+
|
| 99 |
+
utt2num_frames_loader = None
|
| 100 |
+
if utt2num_frames is not None:
|
| 101 |
+
with open(utt2num_frames, 'r') as f:
|
| 102 |
+
utt2num_frames_loader = dict([(x.split()[0], int(x.split()[1])) for x in f.readlines()])
|
| 103 |
+
else:
|
| 104 |
+
utt2num_frames_loader = dict([(k, mel.shape[0]) for k, mel in self.mel_loader.items()])
|
| 105 |
+
|
| 106 |
+
self.utt2num_frames_loader = utt2num_frames_loader
|
| 107 |
+
|
| 108 |
+
# filter by threshold
|
| 109 |
+
if (min_num_frames or max_num_frames) is not None:
|
| 110 |
+
mel_lengths = [utt2num_frames_loader[key] for key in self.utt_ids]
|
| 111 |
+
idxs = [
|
| 112 |
+
idx
|
| 113 |
+
for idx in range(len(self.utt_ids))
|
| 114 |
+
if (min_num_frames is None or mel_lengths[idx] >= min_num_frames) and (max_num_frames is None or mel_lengths[idx] <= max_num_frames)
|
| 115 |
+
]
|
| 116 |
+
if len(self.utt_ids) != len(idxs):
|
| 117 |
+
logging.warning(
|
| 118 |
+
f"Some files are filtered by mel length threshold "
|
| 119 |
+
f"({len(self.utt_ids)} -> {len(idxs)})."
|
| 120 |
+
)
|
| 121 |
+
self.utt_ids = [self.utt_ids[idx] for idx in idxs]
|
| 122 |
+
|
| 123 |
+
# batchify
|
| 124 |
+
if batch_frames is not None:
|
| 125 |
+
self.batches = self.batchify(utt2num_frames_loader, batch_frames=batch_frames)
|
| 126 |
+
elif batch_size is not None:
|
| 127 |
+
self.batches = self.batchify(utt2num_frames_loader, batch_size=batch_size)
|
| 128 |
+
else:
|
| 129 |
+
self.batches = [[utt_id] for utt_id in self.utt_ids]
|
| 130 |
+
|
| 131 |
+
if allow_cache:
|
| 132 |
+
# NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
|
| 133 |
+
self.manager = Manager()
|
| 134 |
+
self.caches = self.manager.dict()
|
| 135 |
+
self.length_tolerance = length_tolerance
|
| 136 |
+
if prompt_fold_by_2:
|
| 137 |
+
self.prompt_len_factor = 2
|
| 138 |
+
else:
|
| 139 |
+
self.prompt_len_factor = 1
|
| 140 |
+
|
| 141 |
+
def batchify(self, utt2num_frames_loader, batch_frames=None, batch_size=None, min_batch_size=1, drop_last=True):
|
| 142 |
+
|
| 143 |
+
assert batch_size is None or batch_size > min_batch_size
|
| 144 |
+
|
| 145 |
+
batches = []
|
| 146 |
+
batch = []
|
| 147 |
+
accum_num_frames = 0
|
| 148 |
+
utt_ids_set = set(self.utt_ids)
|
| 149 |
+
for utt_id, mel_length in tqdm(sorted(list(utt2num_frames_loader.items()), key=lambda x: x[1], reverse=True)):
|
| 150 |
+
if utt_id not in utt_ids_set:
|
| 151 |
+
continue
|
| 152 |
+
if (batch_frames is not None and accum_num_frames + mel_length > batch_frames and len(batch) > min_batch_size) or (batch_size is not None and len(batch) == batch_size):
|
| 153 |
+
batches.append(batch)
|
| 154 |
+
batch = []
|
| 155 |
+
accum_num_frames = 0
|
| 156 |
+
batch.append(utt_id)
|
| 157 |
+
accum_num_frames += mel_length
|
| 158 |
+
if len(batch) > min_batch_size and not drop_last:
|
| 159 |
+
batches.append(batch)
|
| 160 |
+
return batches
|
| 161 |
+
|
| 162 |
+
def __getitem__(self, idx):
|
| 163 |
+
"""Get specified idx items.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
idx (int): Index of the item.
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
str: Utterance id (only in return_utt_id = True).
|
| 170 |
+
ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).
|
| 171 |
+
ndarrays: Features (T', C).
|
| 172 |
+
|
| 173 |
+
"""
|
| 174 |
+
batch = self.batches[idx]
|
| 175 |
+
batch_items = []
|
| 176 |
+
|
| 177 |
+
for utt_id in batch:
|
| 178 |
+
if self.allow_cache and self.caches.get(utt_id) is not None:
|
| 179 |
+
items = self.caches[utt_id]
|
| 180 |
+
else:
|
| 181 |
+
fs, audio = self.audio_loader[utt_id]
|
| 182 |
+
mel = self.mel_loader[utt_id]
|
| 183 |
+
prompt = self.prompt_loader[utt_id]
|
| 184 |
+
vqidx = self.vqidx_loader[utt_id]
|
| 185 |
+
|
| 186 |
+
min_len = min(len(mel), len(vqidx), len(prompt)*self.prompt_len_factor)
|
| 187 |
+
assert ((abs(len(mel) - min_len) <= self.length_tolerance) and
|
| 188 |
+
(abs(len(vqidx) - min_len) <= self.length_tolerance) and
|
| 189 |
+
(abs(len(prompt)*self.prompt_len_factor - min_len) <= self.length_tolerance)), \
|
| 190 |
+
f"Audio feature lengths difference exceeds length tolerance for {utt_id}"
|
| 191 |
+
mel, vqidx, prompt = mel[:min_len], vqidx[:min_len], prompt[:min_len//self.prompt_len_factor]
|
| 192 |
+
|
| 193 |
+
# normalize audio signal to be [-1, 1]
|
| 194 |
+
audio = audio.astype(np.float32)
|
| 195 |
+
audio /= 1 << (16 - 1) # assume that wav is PCM 16 bit
|
| 196 |
+
|
| 197 |
+
if self.return_sampling_rate:
|
| 198 |
+
audio = (audio, fs)
|
| 199 |
+
|
| 200 |
+
if self.return_utt_id:
|
| 201 |
+
items = utt_id, audio, vqidx, mel, prompt
|
| 202 |
+
else:
|
| 203 |
+
items = audio, vqidx, mel, prompt
|
| 204 |
+
|
| 205 |
+
if self.allow_cache:
|
| 206 |
+
self.caches[utt_id] = items
|
| 207 |
+
|
| 208 |
+
batch_items.append(items)
|
| 209 |
+
|
| 210 |
+
return batch_items
|
| 211 |
+
|
| 212 |
+
def __len__(self):
|
| 213 |
+
"""Return dataset length.
|
| 214 |
+
Returns:
|
| 215 |
+
int: The length of dataset.
|
| 216 |
+
"""
|
| 217 |
+
return len(self.batches)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class MelSCPDataset(Dataset):
|
| 221 |
+
"""PyTorch compatible feat dataset based on kaldi-stype scp files."""
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
vqidx_scp,
|
| 226 |
+
prompt_scp,
|
| 227 |
+
return_utt_id=False,
|
| 228 |
+
allow_cache=False,
|
| 229 |
+
):
|
| 230 |
+
"""Initialize dataset.
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
vqidx_scp (str): Kaldi-style feats.scp file.
|
| 234 |
+
prompt_scp (str): Kaldi-style scp file. In this file, every utt is associated with its prompt's mel-spectrogram.
|
| 235 |
+
min_num_frames (int): Threshold to remove short feature files.
|
| 236 |
+
max_num_frames (int): Threshold to remove long feature files.
|
| 237 |
+
return_utt_id (bool): Whether to return utterance id.
|
| 238 |
+
allow_cache (bool): Whether to allow cache of the loaded files.
|
| 239 |
+
"""
|
| 240 |
+
# load scp as lazy dict
|
| 241 |
+
vqidx_loader = _get_feats_scp_loader(vqidx_scp)
|
| 242 |
+
self.prompt_loader = _get_feats_scp_loader(prompt_scp)
|
| 243 |
+
# self.prompt_loader = dict()
|
| 244 |
+
# with open(prompt_scp, 'r') as fr:
|
| 245 |
+
# for line in fr.readlines():
|
| 246 |
+
# terms = line.strip().split()
|
| 247 |
+
# self.prompt_loader[terms[0]] = terms[1]
|
| 248 |
+
vqidx_keys = list(set(self.prompt_loader.keys()) & set(vqidx_loader.keys()))
|
| 249 |
+
|
| 250 |
+
# NOTE: this dataset does not apply filtering, because it is usually used for decoding
|
| 251 |
+
|
| 252 |
+
self.vqidx_loader = vqidx_loader
|
| 253 |
+
self.utt_ids = vqidx_keys
|
| 254 |
+
self.return_utt_id = return_utt_id
|
| 255 |
+
self.allow_cache = allow_cache
|
| 256 |
+
|
| 257 |
+
if allow_cache:
|
| 258 |
+
# NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
|
| 259 |
+
self.manager = Manager()
|
| 260 |
+
self.caches = self.manager.list()
|
| 261 |
+
self.caches += [() for _ in range(len(self.utt_ids))]
|
| 262 |
+
|
| 263 |
+
def __getitem__(self, idx):
|
| 264 |
+
"""Get specified idx items.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
idx (int): Index of the item.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
str: Utterance id (only in return_utt_id = True).
|
| 271 |
+
ndarray: Feature (T', C).
|
| 272 |
+
|
| 273 |
+
"""
|
| 274 |
+
if self.allow_cache and len(self.caches[idx]) != 0:
|
| 275 |
+
return self.caches[idx]
|
| 276 |
+
|
| 277 |
+
utt_id = self.utt_ids[idx]
|
| 278 |
+
vqidx = self.vqidx_loader[utt_id].astype(int)
|
| 279 |
+
|
| 280 |
+
# prompt = torch.load(self.prompt_loader[utt_id]).float().numpy()
|
| 281 |
+
prompt = self.prompt_loader[utt_id]
|
| 282 |
+
|
| 283 |
+
if self.return_utt_id:
|
| 284 |
+
items = utt_id, vqidx, prompt
|
| 285 |
+
else:
|
| 286 |
+
items = vqidx, prompt
|
| 287 |
+
|
| 288 |
+
if self.allow_cache:
|
| 289 |
+
self.caches[idx] = items
|
| 290 |
+
|
| 291 |
+
return items
|
| 292 |
+
|
| 293 |
+
def __len__(self):
|
| 294 |
+
"""Return dataset length.
|
| 295 |
+
|
| 296 |
+
Returns:
|
| 297 |
+
int: The length of dataset.
|
| 298 |
+
|
| 299 |
+
"""
|
| 300 |
+
return len(self.utt_ids)
|
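
For context, all three feature loaders above dispatch on the second column of a kaldi-style scp file, as noted in the comments of _get_feats_scp_loader. Illustrative entries (utterance ids and paths are hypothetical) look like:

utt_id_1 /path/to/feats_1.ark:1024      -> kaldiio.load_scp (kaldi ark with byte offset)
utt_id_2 /path/to/utt_id_2.h5:feats     -> HDF5ScpLoader (HDF5 file with internal dataset path)
utt_id_3 /path/to/utt_id_3.npy          -> NpyScpLoader (one numpy array per utterance)
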
vec2wav2/distributed/__init__.py
ADDED
|
File without changes
|
vec2wav2/distributed/launch.py
ADDED
|
@@ -0,0 +1,163 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""Distributed process launcher.
|
| 5 |
+
|
| 6 |
+
This code is modified from https://github.com/pytorch/pytorch/blob/v1.3.0/torch/distributed/launch.py.
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
import os
|
| 10 |
+
import subprocess
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
from argparse import ArgumentParser
|
| 14 |
+
from argparse import REMAINDER
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def parse_args():
|
| 18 |
+
"""Parse arguments."""
|
| 19 |
+
parser = ArgumentParser(
|
| 20 |
+
description="PyTorch distributed training launch "
|
| 21 |
+
"helper utilty that will spawn up "
|
| 22 |
+
"multiple distributed processes"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# Optional arguments for the launch helper
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--nnodes",
|
| 28 |
+
type=int,
|
| 29 |
+
default=1,
|
| 30 |
+
help="The number of nodes to use for distributed " "training",
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--node_rank",
|
| 34 |
+
type=int,
|
| 35 |
+
default=0,
|
| 36 |
+
help="The rank of the node for multi-node distributed " "training",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--nproc_per_node",
|
| 40 |
+
type=int,
|
| 41 |
+
default=1,
|
| 42 |
+
help="The number of processes to launch on each node, "
|
| 43 |
+
"for GPU training, this is recommended to be set "
|
| 44 |
+
"to the number of GPUs in your system so that "
|
| 45 |
+
"each process can be bound to a single GPU.",
|
| 46 |
+
)
|
| 47 |
+
parser.add_argument(
|
| 48 |
+
"--master_addr",
|
| 49 |
+
default="127.0.0.1",
|
| 50 |
+
type=str,
|
| 51 |
+
help="Master node (rank 0)'s address, should be either "
|
| 52 |
+
"the IP address or the hostname of node 0, for "
|
| 53 |
+
"single node multi-proc training, the "
|
| 54 |
+
"--master_addr can simply be 127.0.0.1",
|
| 55 |
+
)
|
| 56 |
+
parser.add_argument(
|
| 57 |
+
"--master_port",
|
| 58 |
+
default=29500,
|
| 59 |
+
type=int,
|
| 60 |
+
help="Master node (rank 0)'s free port that needs to "
|
| 61 |
+
"be used for communciation during distributed "
|
| 62 |
+
"training",
|
| 63 |
+
)
|
| 64 |
+
parser.add_argument(
|
| 65 |
+
"--use_env",
|
| 66 |
+
default=False,
|
| 67 |
+
action="store_true",
|
| 68 |
+
help="Use environment variable to pass "
|
| 69 |
+
"'local rank'. For legacy reasons, the default value is False. "
|
| 70 |
+
"If set to True, the script will not pass "
|
| 71 |
+
"--local_rank as argument, and will instead set LOCAL_RANK.",
|
| 72 |
+
)
|
| 73 |
+
parser.add_argument(
|
| 74 |
+
"-m",
|
| 75 |
+
"--module",
|
| 76 |
+
default=False,
|
| 77 |
+
action="store_true",
|
| 78 |
+
help="Changes each process to interpret the launch script "
|
| 79 |
+
"as a python module, executing with the same behavior as"
|
| 80 |
+
"'python -m'.",
|
| 81 |
+
)
|
| 82 |
+
parser.add_argument(
|
| 83 |
+
"-c",
|
| 84 |
+
"--command",
|
| 85 |
+
default=False,
|
| 86 |
+
action="store_true",
|
| 87 |
+
help="Changes each process to interpret the launch script " "as a command.",
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# positional
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"training_script",
|
| 93 |
+
type=str,
|
| 94 |
+
help="The full path to the single GPU training "
|
| 95 |
+
"program/script/command to be launched in parallel, "
|
| 96 |
+
"followed by all the arguments for the "
|
| 97 |
+
"training script",
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# rest from the training program
|
| 101 |
+
parser.add_argument("training_script_args", nargs=REMAINDER)
|
| 102 |
+
return parser.parse_args()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
|
| 106 |
+
"""Launch distributed processes."""
|
| 107 |
+
args = parse_args()
|
| 108 |
+
|
| 109 |
+
# world size in terms of number of processes
|
| 110 |
+
dist_world_size = args.nproc_per_node * args.nnodes
|
| 111 |
+
|
| 112 |
+
# set PyTorch distributed related environmental variables
|
| 113 |
+
current_env = os.environ.copy()
|
| 114 |
+
current_env["MASTER_ADDR"] = args.master_addr
|
| 115 |
+
current_env["MASTER_PORT"] = str(args.master_port)
|
| 116 |
+
current_env["WORLD_SIZE"] = str(dist_world_size)
|
| 117 |
+
|
| 118 |
+
processes = []
|
| 119 |
+
|
| 120 |
+
if "OMP_NUM_THREADS" not in os.environ and args.nproc_per_node > 1:
|
| 121 |
+
current_env["OMP_NUM_THREADS"] = str(1)
|
| 122 |
+
print(
|
| 123 |
+
"*****************************************\n"
|
| 124 |
+
"Setting OMP_NUM_THREADS environment variable for each process "
|
| 125 |
+
"to be {} in default, to avoid your system being overloaded, "
|
| 126 |
+
"please further tune the variable for optimal performance in "
|
| 127 |
+
"your application as needed. \n"
|
| 128 |
+
"*****************************************".format(
|
| 129 |
+
current_env["OMP_NUM_THREADS"]
|
| 130 |
+
)
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
for local_rank in range(0, args.nproc_per_node):
|
| 134 |
+
# each process's rank
|
| 135 |
+
dist_rank = args.nproc_per_node * args.node_rank + local_rank
|
| 136 |
+
current_env["RANK"] = str(dist_rank)
|
| 137 |
+
current_env["LOCAL_RANK"] = str(local_rank)
|
| 138 |
+
|
| 139 |
+
# spawn the processes
|
| 140 |
+
if args.command:
|
| 141 |
+
cmd = [args.training_script]
|
| 142 |
+
else:
|
| 143 |
+
cmd = [sys.executable, "-u"]
|
| 144 |
+
if args.module:
|
| 145 |
+
cmd.append("-m")
|
| 146 |
+
cmd.append(args.training_script)
|
| 147 |
+
|
| 148 |
+
if not args.use_env:
|
| 149 |
+
cmd.append("--local_rank={}".format(local_rank))
|
| 150 |
+
|
| 151 |
+
cmd.extend(args.training_script_args)
|
| 152 |
+
|
| 153 |
+
process = subprocess.Popen(cmd, env=current_env)
|
| 154 |
+
processes.append(process)
|
| 155 |
+
|
| 156 |
+
for process in processes:
|
| 157 |
+
process.wait()
|
| 158 |
+
if process.returncode != 0:
|
| 159 |
+
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
main()
|
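
Every process spawned by this launcher receives MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK, and LOCAL_RANK through its environment. A minimal sketch of the worker-side handshake (matching the env:// initialisation used in train.py above; this snippet is illustrative and not part of the commit):

import os
import torch

# Read back the variables exported by the launcher and join the process group.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(
    backend="nccl",
    init_method="env://",
    world_size=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]),
)
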
vec2wav2/layers/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
from .causal_conv import * # NOQA
|
| 2 |
+
from .pqmf import * # NOQA
|
| 3 |
+
from .residual_block import * # NOQA
|
| 4 |
+
from .residual_stack import * # NOQA
|
| 5 |
+
from .tade_res_block import * # NOQA
|
| 6 |
+
from .upsample import * # NOQA
|
vec2wav2/layers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (344 Bytes).
vec2wav2/layers/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (353 Bytes).
vec2wav2/layers/__pycache__/activations.cpython-310.pyc
ADDED
Binary file (6.64 kB).
vec2wav2/layers/__pycache__/causal_conv.cpython-310.pyc
ADDED
Binary file (2.23 kB).
vec2wav2/layers/__pycache__/causal_conv.cpython-39.pyc
ADDED
Binary file (2.24 kB).
vec2wav2/layers/__pycache__/pqmf.cpython-310.pyc
ADDED
Binary file (4.14 kB).
vec2wav2/layers/__pycache__/pqmf.cpython-39.pyc
ADDED
Binary file (4.14 kB).
vec2wav2/layers/__pycache__/residual_block.cpython-310.pyc
ADDED
Binary file (6.21 kB).
vec2wav2/layers/__pycache__/residual_block.cpython-39.pyc
ADDED
Binary file (6.18 kB).
vec2wav2/layers/__pycache__/residual_stack.cpython-310.pyc
ADDED
Binary file (2.51 kB).
vec2wav2/layers/__pycache__/residual_stack.cpython-39.pyc
ADDED
Binary file (2.51 kB).
vec2wav2/layers/__pycache__/tade_res_block.cpython-310.pyc
ADDED
Binary file (3.59 kB).
vec2wav2/layers/__pycache__/tade_res_block.cpython-39.pyc
ADDED
Binary file (3.56 kB).
vec2wav2/layers/__pycache__/upsample.cpython-310.pyc
ADDED
Binary file (6.01 kB).
vec2wav2/layers/__pycache__/upsample.cpython-39.pyc
ADDED
Binary file (6 kB).
vec2wav2/layers/activations.py
ADDED
|
@@ -0,0 +1,197 @@
|
| 1 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
# Modified by Yiwei Guo, 2024
|
| 5 |
+
# including conditioned snakebeta activation
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn, sin, pow
|
| 9 |
+
from torch.nn import Parameter
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Snake(nn.Module):
|
| 13 |
+
'''
|
| 14 |
+
Implementation of a sine-based periodic activation function
|
| 15 |
+
Shape:
|
| 16 |
+
- Input: (B, C, T)
|
| 17 |
+
- Output: (B, C, T), same shape as the input
|
| 18 |
+
Parameters:
|
| 19 |
+
- alpha - trainable parameter
|
| 20 |
+
References:
|
| 21 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 22 |
+
https://arxiv.org/abs/2006.08195
|
| 23 |
+
Examples:
|
| 24 |
+
>>> a1 = snake(256)
|
| 25 |
+
>>> x = torch.randn(256)
|
| 26 |
+
>>> x = a1(x)
|
| 27 |
+
'''
|
| 28 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 29 |
+
'''
|
| 30 |
+
Initialization.
|
| 31 |
+
INPUT:
|
| 32 |
+
- in_features: shape of the input
|
| 33 |
+
- alpha: trainable parameter
|
| 34 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 35 |
+
alpha will be trained along with the rest of your model.
|
| 36 |
+
'''
|
| 37 |
+
super(Snake, self).__init__()
|
| 38 |
+
self.in_features = in_features
|
| 39 |
+
|
| 40 |
+
# initialize alpha
|
| 41 |
+
self.alpha_logscale = alpha_logscale
|
| 42 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 43 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 44 |
+
else: # linear scale alphas initialized to ones
|
| 45 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 46 |
+
|
| 47 |
+
self.alpha.requires_grad = alpha_trainable
|
| 48 |
+
|
| 49 |
+
self.no_div_by_zero = 0.000000001
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
'''
|
| 53 |
+
Forward pass of the function.
|
| 54 |
+
Applies the function to the input elementwise.
|
| 55 |
+
Snake := x + 1/a * sin^2 (xa)
|
| 56 |
+
'''
|
| 57 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 58 |
+
if self.alpha_logscale:
|
| 59 |
+
alpha = torch.exp(alpha)
|
| 60 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 61 |
+
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SnakeBeta(nn.Module):
|
| 66 |
+
'''
|
| 67 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 68 |
+
Shape:
|
| 69 |
+
- Input: (B, C, T)
|
| 70 |
+
- Output: (B, C, T), same shape as the input
|
| 71 |
+
Parameters:
|
| 72 |
+
- alpha - trainable parameter that controls frequency
|
| 73 |
+
- beta - trainable parameter that controls magnitude
|
| 74 |
+
References:
|
| 75 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 76 |
+
https://arxiv.org/abs/2006.08195
|
| 77 |
+
Examples:
|
| 78 |
+
>>> a1 = snakebeta(256)
|
| 79 |
+
>>> x = torch.randn(256)
|
| 80 |
+
>>> x = a1(x)
|
| 81 |
+
'''
|
| 82 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 83 |
+
'''
|
| 84 |
+
Initialization.
|
| 85 |
+
INPUT:
|
| 86 |
+
- in_features: shape of the input
|
| 87 |
+
- alpha - trainable parameter that controls frequency
|
| 88 |
+
- beta - trainable parameter that controls magnitude
|
| 89 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 90 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 91 |
+
alpha will be trained along with the rest of your model.
|
| 92 |
+
'''
|
| 93 |
+
super(SnakeBeta, self).__init__()
|
| 94 |
+
self.in_features = in_features
|
| 95 |
+
|
| 96 |
+
# initialize alpha
|
| 97 |
+
self.alpha_logscale = alpha_logscale
|
| 98 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 99 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 100 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
| 101 |
+
else: # linear scale alphas initialized to ones
|
| 102 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 103 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
| 104 |
+
|
| 105 |
+
self.alpha.requires_grad = alpha_trainable
|
| 106 |
+
self.beta.requires_grad = alpha_trainable
|
| 107 |
+
|
| 108 |
+
self.no_div_by_zero = 0.000000001
|
| 109 |
+
|
| 110 |
+
def forward(self, x, cond=None):
|
| 111 |
+
'''
|
| 112 |
+
Forward pass of the function.
|
| 113 |
+
Applies the function to the input elementwise.
|
| 114 |
+
SnakeBeta := x + 1/b * sin^2 (xa)
|
| 115 |
+
'''
|
| 116 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 117 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
| 118 |
+
if self.alpha_logscale:
|
| 119 |
+
alpha = torch.exp(alpha)
|
| 120 |
+
beta = torch.exp(beta)
|
| 121 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 122 |
+
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class SnakeBetaWithCondition(nn.Module):
|
| 127 |
+
'''
|
| 128 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 129 |
+
Shape:
|
| 130 |
+
- Input: (B, C, T)
|
| 131 |
+
- Condition: (B, D), where D-dimension will be mapped to C dimensions
|
| 132 |
+
- Output: (B, C, T), same shape as the input
|
| 133 |
+
Parameters:
|
| 134 |
+
- alpha - trainable parameter that controls frequency
|
| 135 |
+
- beta - trainable parameter that controls magnitude
|
| 136 |
+
- condition_alpha_prenet - trainable parameter that controls alpha and beta using condition
|
| 137 |
+
References:
|
| 138 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 139 |
+
https://arxiv.org/abs/2006.08195
|
| 140 |
+
Examples:
|
| 141 |
+
>>> a1 = snakebeta(256, 128)
|
| 142 |
+
>>> x = torch.randn(256)
|
| 143 |
+
>>> cond = torch.randn(128)
|
| 144 |
+
>>> x = a1(x, cond)
|
| 145 |
+
'''
|
| 146 |
+
def __init__(self, in_features, condition_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 147 |
+
'''
|
| 148 |
+
Initialization.
|
| 149 |
+
INPUT:
|
| 150 |
+
- in_features: dimension of the input
|
| 151 |
+
- condition_features: dimension of the condition vectors
|
| 152 |
+
- alpha - trainable parameter that controls frequency
|
| 153 |
+
- beta - trainable parameter that controls magnitude
|
| 154 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 155 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 156 |
+
alpha, beta will be trained along with the rest of your model.
|
| 157 |
+
'''
|
| 158 |
+
super(SnakeBetaWithCondition, self).__init__()
|
| 159 |
+
self.in_features = in_features
|
| 160 |
+
|
| 161 |
+
self.condition_alpha_prenet = torch.nn.Linear(condition_features, in_features)
|
| 162 |
+
# self.condition_beta_prenet = torch.nn.Linear(condition_features, in_features)
|
| 163 |
+
|
| 164 |
+
# initialize alpha
|
| 165 |
+
self.alpha_logscale = alpha_logscale
|
| 166 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 167 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 168 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
| 169 |
+
else: # linear scale alphas initialized to ones
|
| 170 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 171 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
| 172 |
+
|
| 173 |
+
self.alpha.requires_grad = alpha_trainable
|
| 174 |
+
self.beta.requires_grad = alpha_trainable
|
| 175 |
+
|
| 176 |
+
self.no_div_by_zero = 0.000000001
|
| 177 |
+
|
| 178 |
+
def forward(self, x, condition):
|
| 179 |
+
'''
|
| 180 |
+
condition: [B, D]
|
| 181 |
+
Forward pass of the function.
|
| 182 |
+
Applies the function to the input elementwise.
|
| 183 |
+
SnakeBeta := x + 1/b * sin^2 (xa)
|
| 184 |
+
'''
|
| 185 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 186 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
| 187 |
+
if self.alpha_logscale:
|
| 188 |
+
alpha = torch.exp(alpha)
|
| 189 |
+
beta = torch.exp(beta)
|
| 190 |
+
|
| 191 |
+
condition = torch.tanh(self.condition_alpha_prenet(condition).unsqueeze(-1)) # Same prenet for both alpha and beta, to save parameters
|
| 192 |
+
alpha = alpha + condition
|
| 193 |
+
beta = beta + 0.5 * condition # multiply 0.5 for avoiding beta being too small
|
| 194 |
+
|
| 195 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 196 |
+
|
| 197 |
+
return x
|
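
A small usage sketch of the conditioned activation defined above (tensor sizes are arbitrary and chosen only to match the documented (B, C, T) input and (B, D) condition shapes; not part of the committed file):

import torch
from vec2wav2.layers.activations import SnakeBetaWithCondition

act = SnakeBetaWithCondition(in_features=256, condition_features=128, alpha_logscale=True)
x = torch.randn(4, 256, 100)   # (B, C, T) feature map
cond = torch.randn(4, 128)     # (B, D) condition, e.g. a speaker prompt embedding
y = act(x, cond)               # output keeps the (B, C, T) shape
assert y.shape == x.shape
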
vec2wav2/layers/causal_conv.py
ADDED
|
@@ -0,0 +1,66 @@
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
|
| 3 |
+
# Copyright 2020 Tomoki Hayashi
|
| 4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
| 5 |
+
|
| 6 |
+
"""Causal convolusion layer modules."""
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CausalConv1d(torch.nn.Module):
|
| 13 |
+
"""CausalConv1d module with customized initialization."""
|
| 14 |
+
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
in_channels,
|
| 18 |
+
out_channels,
|
| 19 |
+
kernel_size,
|
| 20 |
+
dilation=1,
|
| 21 |
+
bias=True,
|
| 22 |
+
pad="ConstantPad1d",
|
| 23 |
+
pad_params={"value": 0.0},
|
| 24 |
+
):
|
| 25 |
+
"""Initialize CausalConv1d module."""
|
| 26 |
+
super(CausalConv1d, self).__init__()
|
| 27 |
+
self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
|
| 28 |
+
self.conv = torch.nn.Conv1d(
|
| 29 |
+
in_channels, out_channels, kernel_size, dilation=dilation, bias=bias
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
def forward(self, x):
|
| 33 |
+
"""Calculate forward propagation.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
x (Tensor): Input tensor (B, in_channels, T).
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
Tensor: Output tensor (B, out_channels, T).
|
| 40 |
+
|
| 41 |
+
"""
|
| 42 |
+
return self.conv(self.pad(x))[:, :, : x.size(2)]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class CausalConvTranspose1d(torch.nn.Module):
|
| 46 |
+
"""CausalConvTranspose1d module with customized initialization."""
|
| 47 |
+
|
| 48 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
|
| 49 |
+
"""Initialize CausalConvTranspose1d module."""
|
| 50 |
+
super(CausalConvTranspose1d, self).__init__()
|
| 51 |
+
self.deconv = torch.nn.ConvTranspose1d(
|
| 52 |
+
in_channels, out_channels, kernel_size, stride, bias=bias
|
| 53 |
+
)
|
| 54 |
+
self.stride = stride
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
"""Calculate forward propagation.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
x (Tensor): Input tensor (B, in_channels, T_in).
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Tensor: Output tensor (B, out_channels, T_out).
|
| 64 |
+
|
| 65 |
+
"""
|
| 66 |
+
return self.deconv(x)[:, :, : -self.stride]
|
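
A minimal usage sketch for CausalConv1d added above; left-only padding keeps the output length equal to the input length and prevents the layer from seeing future samples:

```python
import torch
from vec2wav2.layers.causal_conv import CausalConv1d

conv = CausalConv1d(in_channels=64, out_channels=64, kernel_size=3, dilation=2)
x = torch.randn(1, 64, 100)    # (B, in_channels, T)
y = conv(x)                    # (1, 64, 100): output at time t depends only on inputs <= t
```
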
vec2wav2/layers/pqmf.py
ADDED
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Pseudo QMF modules."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from scipy.signal import kaiser
+
+
+def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
+    """Design prototype filter for PQMF.
+
+    This method is based on `A Kaiser window approach for the design of prototype
+    filters of cosine modulated filterbanks`_.
+
+    Args:
+        taps (int): The number of filter taps.
+        cutoff_ratio (float): Cut-off frequency ratio.
+        beta (float): Beta coefficient for kaiser window.
+
+    Returns:
+        ndarray: Impulse response of prototype filter (taps + 1,).
+
+    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+        https://ieeexplore.ieee.org/abstract/document/681427
+
+    """
+    # check the arguments are valid
+    assert taps % 2 == 0, "The number of taps must be an even number."
+    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+    # make initial filter
+    omega_c = np.pi * cutoff_ratio
+    with np.errstate(invalid="ignore"):
+        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
+            np.pi * (np.arange(taps + 1) - 0.5 * taps)
+        )
+    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form
+
+    # apply kaiser window
+    w = kaiser(taps + 1, beta)
+    h = h_i * w
+
+    return h
+
+
+class PQMF(torch.nn.Module):
+    """PQMF module.
+
+    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+
+    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+        https://ieeexplore.ieee.org/document/258122
+
+    """
+
+    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
+        """Initialize PQMF module.
+
+        The cutoff_ratio and beta parameters are optimized for #subbands = 4.
+        See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
+
+        Args:
+            subbands (int): The number of subbands.
+            taps (int): The number of filter taps.
+            cutoff_ratio (float): Cut-off frequency ratio.
+            beta (float): Beta coefficient for kaiser window.
+
+        """
+        super(PQMF, self).__init__()
+
+        # build analysis & synthesis filter coefficients
+        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+        h_analysis = np.zeros((subbands, len(h_proto)))
+        h_synthesis = np.zeros((subbands, len(h_proto)))
+        for k in range(subbands):
+            h_analysis[k] = (
+                2
+                * h_proto
+                * np.cos(
+                    (2 * k + 1)
+                    * (np.pi / (2 * subbands))
+                    * (np.arange(taps + 1) - (taps / 2))
+                    + (-1) ** k * np.pi / 4
+                )
+            )
+            h_synthesis[k] = (
+                2
+                * h_proto
+                * np.cos(
+                    (2 * k + 1)
+                    * (np.pi / (2 * subbands))
+                    * (np.arange(taps + 1) - (taps / 2))
+                    - (-1) ** k * np.pi / 4
+                )
+            )
+
+        # convert to tensor
+        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
+        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
+
+        # register coefficients as buffer
+        self.register_buffer("analysis_filter", analysis_filter)
+        self.register_buffer("synthesis_filter", synthesis_filter)
+
+        # filter for downsampling & upsampling
+        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
+        for k in range(subbands):
+            updown_filter[k, k, 0] = 1.0
+        self.register_buffer("updown_filter", updown_filter)
+        self.subbands = subbands
+
+        # keep padding info
+        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
+
+    def analysis(self, x):
+        """Analysis with PQMF.
+
+        Args:
+            x (Tensor): Input tensor (B, 1, T).
+
+        Returns:
+            Tensor: Output tensor (B, subbands, T // subbands).
+
+        """
+        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+        return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+    def synthesis(self, x):
+        """Synthesis with PQMF.
+
+        Args:
+            x (Tensor): Input tensor (B, subbands, T // subbands).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T).
+
+        """
+        # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
+        #   Not sure this is the correct way, it is better to check again.
+        # TODO(kan-bayashi): Understand the reconstruction procedure
+        x = F.conv_transpose1d(
+            x, self.updown_filter * self.subbands, stride=self.subbands
+        )
+        return F.conv1d(self.pad_fn(x), self.synthesis_filter)
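
A minimal round-trip sketch for the PQMF module above, assuming the default 4-subband configuration:

```python
import torch
from vec2wav2.layers.pqmf import PQMF

pqmf = PQMF(subbands=4)              # default taps/cutoff_ratio/beta are tuned for 4 subbands
wav = torch.randn(1, 1, 16000)       # (B, 1, T) waveform
bands = pqmf.analysis(wav)           # (1, 4, 4000) subband signals
recon = pqmf.synthesis(bands)        # (1, 1, 16000) near-perfect reconstruction
```
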
vec2wav2/layers/residual_block.py
ADDED
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+
+"""Residual block modules.
+
+References:
+    - https://github.com/r9y9/wavenet_vocoder
+    - https://github.com/jik876/hifi-gan
+
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+class Conv1d(torch.nn.Conv1d):
+    """Conv1d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv1d module."""
+        super(Conv1d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+    """1x1 Conv1d with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, bias):
+        """Initialize 1x1 Conv1d module."""
+        super(Conv1d1x1, self).__init__(
+            in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+        )
+
+
+class WaveNetResidualBlock(torch.nn.Module):
+    """Residual block module in WaveNet."""
+
+    def __init__(
+        self,
+        kernel_size=3,
+        residual_channels=64,
+        gate_channels=128,
+        skip_channels=64,
+        aux_channels=80,
+        dropout=0.0,
+        dilation=1,
+        bias=True,
+        use_causal_conv=False,
+    ):
+        """Initialize WaveNetResidualBlock module.
+
+        Args:
+            kernel_size (int): Kernel size of dilation convolution layer.
+            residual_channels (int): Number of channels for residual connection.
+            gate_channels (int): Number of channels for the gated activation.
+            skip_channels (int): Number of channels for skip connection.
+            aux_channels (int): Local conditioning channels, i.e., auxiliary input dimension.
+            dropout (float): Dropout probability.
+            dilation (int): Dilation factor.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            use_causal_conv (bool): Whether to use causal or non-causal convolution.
+
+        """
+        super().__init__()
+        self.dropout = dropout
+        # no future time stamps available
+        if use_causal_conv:
+            padding = (kernel_size - 1) * dilation
+        else:
+            assert (kernel_size - 1) % 2 == 0, "Kernel size must be an odd number."
+            padding = (kernel_size - 1) // 2 * dilation
+        self.use_causal_conv = use_causal_conv
+
+        # dilation conv
+        self.conv = Conv1d(
+            residual_channels,
+            gate_channels,
+            kernel_size,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+
+        # local conditioning
+        if aux_channels > 0:
+            self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False)
+        else:
+            self.conv1x1_aux = None
+
+        # conv output is split into two groups
+        gate_out_channels = gate_channels // 2
+        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
+        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias)
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, residual_channels, T).
+            c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T).
+
+        Returns:
+            Tensor: Output tensor for residual connection (B, residual_channels, T).
+            Tensor: Output tensor for skip connection (B, skip_channels, T).
+
+        """
+        residual = x
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.conv(x)
+
+        # remove future time steps if causal convolution is used
+        x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x
+
+        # split into two parts for gated activation
+        splitdim = 1
+        xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+        # local conditioning
+        if c is not None:
+            assert self.conv1x1_aux is not None
+            c = self.conv1x1_aux(c)
+            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+            xa, xb = xa + ca, xb + cb
+
+        x = torch.tanh(xa) * torch.sigmoid(xb)
+
+        # for skip connection
+        s = self.conv1x1_skip(x)
+
+        # for residual connection
+        x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5)
+
+        return x, s
+
+
+class HiFiGANResidualBlock(torch.nn.Module):
+    """Residual block module in HiFiGAN."""
+
+    def __init__(
+        self,
+        kernel_size=3,
+        channels=512,
+        dilations=(1, 3, 5),
+        bias=True,
+        use_additional_convs=True,
+        nonlinear_activation="LeakyReLU",
+        nonlinear_activation_params={"negative_slope": 0.1},
+    ):
+        """Initialize HiFiGANResidualBlock module.
+
+        Args:
+            kernel_size (int): Kernel size of dilation convolution layer.
+            channels (int): Number of channels for convolution layer.
+            dilations (List[int]): List of dilation factors.
+            use_additional_convs (bool): Whether to use additional convolution layers.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+
+        """
+        super().__init__()
+        self.use_additional_convs = use_additional_convs
+        self.convs1 = torch.nn.ModuleList()
+        if use_additional_convs:
+            self.convs2 = torch.nn.ModuleList()
+        assert kernel_size % 2 == 1, "Kernel size must be an odd number."
+        for dilation in dilations:
+            self.convs1 += [
+                torch.nn.Sequential(
+                    getattr(torch.nn, nonlinear_activation)(
+                        **nonlinear_activation_params
+                    ),
+                    torch.nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation,
+                        bias=bias,
+                        padding=(kernel_size - 1) // 2 * dilation,
+                    ),
+                )
+            ]
+            if use_additional_convs:
+                self.convs2 += [
+                    torch.nn.Sequential(
+                        getattr(torch.nn, nonlinear_activation)(
+                            **nonlinear_activation_params
+                        ),
+                        torch.nn.Conv1d(
+                            channels,
+                            channels,
+                            kernel_size,
+                            1,
+                            dilation=1,
+                            bias=bias,
+                            padding=(kernel_size - 1) // 2,
+                        ),
+                    )
+                ]
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, channels, T).
+
+        """
+        for idx in range(len(self.convs1)):
+            xt = self.convs1[idx](x)
+            if self.use_additional_convs:
+                xt = self.convs2[idx](xt)
+            x = xt + x
+        return x
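
A minimal usage sketch for HiFiGANResidualBlock above; the block keeps the time resolution of its input unchanged:

```python
import torch
from vec2wav2.layers.residual_block import HiFiGANResidualBlock

block = HiFiGANResidualBlock(kernel_size=3, channels=512, dilations=(1, 3, 5))
x = torch.randn(1, 512, 100)   # (B, channels, T)
y = block(x)                   # (1, 512, 100)
```
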
vec2wav2/layers/residual_stack.py
ADDED
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Residual stack module in MelGAN."""
+
+import torch
+
+from vec2wav2.layers import CausalConv1d
+
+
+class ResidualStack(torch.nn.Module):
+    """Residual stack module introduced in MelGAN."""
+
+    def __init__(
+        self,
+        kernel_size=3,
+        channels=32,
+        dilation=1,
+        bias=True,
+        nonlinear_activation="LeakyReLU",
+        nonlinear_activation_params={"negative_slope": 0.2},
+        pad="ReflectionPad1d",
+        pad_params={},
+        use_causal_conv=False,
+    ):
+        """Initialize ResidualStack module.
+
+        Args:
+            kernel_size (int): Kernel size of dilation convolution layer.
+            channels (int): Number of channels of convolution layers.
+            dilation (int): Dilation factor.
+            bias (bool): Whether to add bias parameter in convolution layers.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_causal_conv (bool): Whether to use causal convolution.
+
+        """
+        super(ResidualStack, self).__init__()
+
+        # define residual stack part
+        if not use_causal_conv:
+            assert (kernel_size - 1) % 2 == 0, "Kernel size must be an odd number."
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
+                torch.nn.Conv1d(
+                    channels, channels, kernel_size, dilation=dilation, bias=bias
+                ),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+        else:
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                CausalConv1d(
+                    channels,
+                    channels,
+                    kernel_size,
+                    dilation=dilation,
+                    bias=bias,
+                    pad=pad,
+                    pad_params=pad_params,
+                ),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+
+        # define extra layer for skip connection
+        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, channels, T).
+
+        """
+        return self.stack(c) + self.skip_layer(c)
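
A minimal usage sketch for ResidualStack above:

```python
import torch
from vec2wav2.layers.residual_stack import ResidualStack

stack = ResidualStack(kernel_size=3, channels=32, dilation=2)
c = torch.randn(1, 32, 200)    # (B, channels, T)
out = stack(c)                 # (1, 32, 200)
```
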
vec2wav2/layers/tade_res_block.py
ADDED
@@ -0,0 +1,160 @@
+# Copyright 2021 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""StyleMelGAN's TADEResBlock Modules."""
+
+from functools import partial
+
+import torch
+
+
+class TADELayer(torch.nn.Module):
+    """TADE Layer module."""
+
+    def __init__(
+        self,
+        in_channels=64,
+        aux_channels=80,
+        kernel_size=9,
+        bias=True,
+        upsample_factor=2,
+        upsample_mode="nearest",
+    ):
+        """Initialize TADE layer."""
+        super().__init__()
+        self.norm = torch.nn.InstanceNorm1d(in_channels)
+        self.aux_conv = torch.nn.Sequential(
+            torch.nn.Conv1d(
+                aux_channels,
+                in_channels,
+                kernel_size,
+                1,
+                bias=bias,
+                padding=(kernel_size - 1) // 2,
+            ),
+            # NOTE(kan-bayashi): Use non-linear activation?
+        )
+        self.gated_conv = torch.nn.Sequential(
+            torch.nn.Conv1d(
+                in_channels,
+                in_channels * 2,
+                kernel_size,
+                1,
+                bias=bias,
+                padding=(kernel_size - 1) // 2,
+            ),
+            # NOTE(kan-bayashi): Use non-linear activation?
+        )
+        self.upsample = torch.nn.Upsample(
+            scale_factor=upsample_factor, mode=upsample_mode
+        )
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+            c (Tensor): Auxiliary input tensor (B, aux_channels, T').
+
+        Returns:
+            Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
+            Tensor: Upsampled aux tensor (B, in_channels, T * aux_upsample_factor).
+
+        """
+        x = self.norm(x)
+        c = self.upsample(c)
+        c = self.aux_conv(c)
+        cg = self.gated_conv(c)
+        cg1, cg2 = cg.split(cg.size(1) // 2, dim=1)
+        # NOTE(kan-bayashi): Use upsample for noise input here?
+        y = cg1 * self.upsample(x) + cg2
+        # NOTE(kan-bayashi): Return upsampled aux here?
+        return y, c
+
+
+class TADEResBlock(torch.nn.Module):
+    """TADEResBlock module."""
+
+    def __init__(
+        self,
+        in_channels=64,
+        aux_channels=80,
+        kernel_size=9,
+        dilation=2,
+        bias=True,
+        upsample_factor=2,
+        upsample_mode="nearest",
+        gated_function="softmax",
+    ):
+        """Initialize TADEResBlock module."""
+        super().__init__()
+        self.tade1 = TADELayer(
+            in_channels=in_channels,
+            aux_channels=aux_channels,
+            kernel_size=kernel_size,
+            bias=bias,
+            # NOTE(kan-bayashi): Use upsample in the first TADE layer?
+            upsample_factor=1,
+            upsample_mode=upsample_mode,
+        )
+        self.gated_conv1 = torch.nn.Conv1d(
+            in_channels,
+            in_channels * 2,
+            kernel_size,
+            1,
+            bias=bias,
+            padding=(kernel_size - 1) // 2,
+        )
+        self.tade2 = TADELayer(
+            in_channels=in_channels,
+            aux_channels=in_channels,
+            kernel_size=kernel_size,
+            bias=bias,
+            upsample_factor=upsample_factor,
+            upsample_mode=upsample_mode,
+        )
+        self.gated_conv2 = torch.nn.Conv1d(
+            in_channels,
+            in_channels * 2,
+            kernel_size,
+            1,
+            bias=bias,
+            dilation=dilation,
+            padding=(kernel_size - 1) // 2 * dilation,
+        )
+        self.upsample = torch.nn.Upsample(
+            scale_factor=upsample_factor, mode=upsample_mode
+        )
+        if gated_function == "softmax":
+            self.gated_function = partial(torch.softmax, dim=1)
+        elif gated_function == "sigmoid":
+            self.gated_function = torch.sigmoid
+        else:
+            raise ValueError(f"{gated_function} is not supported.")
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+            c (Tensor): Auxiliary input tensor (B, aux_channels, T').
+
+        Returns:
+            Tensor: Output tensor (B, in_channels, T * in_upsample_factor).
+            Tensor: Upsampled auxiliary tensor (B, in_channels, T * in_upsample_factor).
+
+        """
+        residual = x
+
+        x, c = self.tade1(x, c)
+        x = self.gated_conv1(x)
+        xa, xb = x.split(x.size(1) // 2, dim=1)
+        x = self.gated_function(xa) * torch.tanh(xb)
+
+        x, c = self.tade2(x, c)
+        x = self.gated_conv2(x)
+        xa, xb = x.split(x.size(1) // 2, dim=1)
+        x = self.gated_function(xa) * torch.tanh(xb)
+
+        # NOTE(kan-bayashi): Return upsampled aux here?
+        return self.upsample(residual) + x, c
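
A minimal usage sketch for TADEResBlock above, assuming one 2x upsampling stage conditioned on 80-dimensional auxiliary features:

```python
import torch
from vec2wav2.layers.tade_res_block import TADEResBlock

block = TADEResBlock(in_channels=64, aux_channels=80, upsample_factor=2)
x = torch.randn(1, 64, 50)     # (B, in_channels, T) hidden features
c = torch.randn(1, 80, 50)     # (B, aux_channels, T') conditioning features
y, c_up = block(x, c)          # y: (1, 64, 100); the conditioning is also returned upsampled
```
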
vec2wav2/layers/upsample.py
ADDED
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+
+"""Upsampling module.
+
+This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from vec2wav2.layers import Conv1d
+
+
+class Stretch2d(torch.nn.Module):
+    """Stretch2d module."""
+
+    def __init__(self, x_scale, y_scale, mode="nearest"):
+        """Initialize Stretch2d module.
+
+        Args:
+            x_scale (int): X scaling factor (Time axis in spectrogram).
+            y_scale (int): Y scaling factor (Frequency axis in spectrogram).
+            mode (str): Interpolation mode.
+
+        """
+        super(Stretch2d, self).__init__()
+        self.x_scale = x_scale
+        self.y_scale = y_scale
+        self.mode = mode
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, C, F, T).
+
+        Returns:
+            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
+
+        """
+        return F.interpolate(
+            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
+        )
+
+
+class Conv2d(torch.nn.Conv2d):
+    """Conv2d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv2d module."""
+        super(Conv2d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        self.weight.data.fill_(1.0 / np.prod(self.kernel_size))
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class UpsampleNetwork(torch.nn.Module):
+    """Upsampling network module."""
+
+    def __init__(
+        self,
+        upsample_scales,
+        nonlinear_activation=None,
+        nonlinear_activation_params={},
+        interpolate_mode="nearest",
+        freq_axis_kernel_size=1,
+        use_causal_conv=False,
+    ):
+        """Initialize upsampling network module.
+
+        Args:
+            upsample_scales (list): List of upsampling scales.
+            nonlinear_activation (str): Activation function name.
+            nonlinear_activation_params (dict): Arguments for specified activation function.
+            interpolate_mode (str): Interpolation mode.
+            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
+
+        """
+        super(UpsampleNetwork, self).__init__()
+        self.use_causal_conv = use_causal_conv
+        self.up_layers = torch.nn.ModuleList()
+        for scale in upsample_scales:
+            # interpolation layer
+            stretch = Stretch2d(scale, 1, interpolate_mode)
+            self.up_layers += [stretch]
+
+            # conv layer
+            assert (
+                freq_axis_kernel_size - 1
+            ) % 2 == 0, "Freq axis kernel size must be an odd number."
+            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
+            if use_causal_conv:
+                padding = (freq_axis_padding, scale * 2)
+            else:
+                padding = (freq_axis_padding, scale)
+            conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+            self.up_layers += [conv]
+
+            # nonlinear
+            if nonlinear_activation is not None:
+                nonlinear = getattr(torch.nn, nonlinear_activation)(
+                    **nonlinear_activation_params
+                )
+                self.up_layers += [nonlinear]
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c : Input tensor (B, C, T).
+
+        Returns:
+            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).
+
+        """
+        c = c.unsqueeze(1)  # (B, 1, C, T)
+        for f in self.up_layers:
+            if self.use_causal_conv and isinstance(f, Conv2d):
+                c = f(c)[..., : c.size(-1)]
+            else:
+                c = f(c)
+        return c.squeeze(1)  # (B, C, T')
+
+
+class ConvInUpsampleNetwork(torch.nn.Module):
+    """Convolution + upsampling network module."""
+
+    def __init__(
+        self,
+        upsample_scales,
+        nonlinear_activation=None,
+        nonlinear_activation_params={},
+        interpolate_mode="nearest",
+        freq_axis_kernel_size=1,
+        aux_channels=80,
+        aux_context_window=0,
+        use_causal_conv=False,
+    ):
+        """Initialize convolution + upsampling network module.
+
+        Args:
+            upsample_scales (list): List of upsampling scales.
+            nonlinear_activation (str): Activation function name.
+            nonlinear_activation_params (dict): Arguments for specified activation function.
+            interpolate_mode (str): Interpolation mode.
+            freq_axis_kernel_size (int): Kernel size in the direction of frequency axis.
+            aux_channels (int): Number of channels of pre-convolutional layer.
+            aux_context_window (int): Context window size of the pre-convolutional layer.
+            use_causal_conv (bool): Whether to use causal structure.
+
+        """
+        super(ConvInUpsampleNetwork, self).__init__()
+        self.aux_context_window = aux_context_window
+        self.use_causal_conv = use_causal_conv and aux_context_window > 0
+        # To capture wide-context information in conditional features
+        kernel_size = (
+            aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1
+        )
+        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
+        self.conv_in = Conv1d(
+            aux_channels, aux_channels, kernel_size=kernel_size, bias=False
+        )
+        self.upsample = UpsampleNetwork(
+            upsample_scales=upsample_scales,
+            nonlinear_activation=nonlinear_activation,
+            nonlinear_activation_params=nonlinear_activation_params,
+            interpolate_mode=interpolate_mode,
+            freq_axis_kernel_size=freq_axis_kernel_size,
+            use_causal_conv=use_causal_conv,
+        )
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c : Input tensor (B, C, T').
+
+        Returns:
+            Tensor: Upsampled tensor (B, C, T),
+                where T = (T' - aux_context_window * 2) * prod(upsample_scales).
+
+        Note:
+            The length of inputs considers the context window size.
+
+        """
+        c_ = self.conv_in(c)
+        c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_
+        return self.upsample(c)
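
A minimal usage sketch for ConvInUpsampleNetwork above; the hop size of 320 samples per frame and its scale decomposition are hypothetical values chosen for illustration:

```python
import torch
from vec2wav2.layers.upsample import ConvInUpsampleNetwork

upsampler = ConvInUpsampleNetwork(
    upsample_scales=[4, 4, 4, 5],   # prod = 320 samples per frame (hypothetical hop size)
    aux_channels=80,
    aux_context_window=2,
)
c = torch.randn(1, 80, 104)          # 100 frames plus 2 context frames on each side
out = upsampler(c)                   # (1, 80, 32000) = (1, 80, 100 * 320)
```
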
vec2wav2/losses/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .adversarial_loss import *  # NOQA
+from .feat_match_loss import *  # NOQA
+from .mel_loss import *  # NOQA
+from .stft_loss import *  # NOQA