| import copy |
| import logging |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.utils import weight_norm, spectral_norm |
| from einops import rearrange |
|
|
|
|
class HiFiGANPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN period discriminator module.

    Reshapes the 1D input signal (B, C, T) into a 2D grid
    (B, C, T // period, period) and applies a stack of strided 2D
    convolutions, so the discriminator focuses on periodic structure
    at one specific period.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        period=3,
        kernel_sizes=None,
        channels=32,
        downsample_scales=None,
        channel_increasing_factor=4,
        max_downsample_channels=1024,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params=None,
        use_weight_norm=True,
    ):
        """Initialize HiFiGANPeriodDiscriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            period (int): Period used to fold the time axis into 2D.
            kernel_sizes (list): Kernel sizes of initial conv layers and the
                final conv layer. Defaults to [5, 3].
            channels (int): Number of initial channels.
            downsample_scales (list): List of downsampling scales.
                Defaults to [3, 3, 3, 3, 1].
            channel_increasing_factor (int): Channel growth factor applied
                after each downsampling conv (capped by
                max_downsample_channels).
            max_downsample_channels (int): Number of maximum downsampling
                channels.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation
                function. Defaults to {"negative_slope": 0.1}.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
        """
        super().__init__()
        # Use None sentinels to avoid shared mutable default arguments.
        if kernel_sizes is None:
            kernel_sizes = [5, 3]
        if downsample_scales is None:
            downsample_scales = [3, 3, 3, 3, 1]
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {"negative_slope": 0.1}
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1, "Kernel size must be odd number."
        assert kernel_sizes[1] % 2 == 1, "Kernel size must be odd number."

        self.period = period
        self.convs = torch.nn.ModuleList()
        in_chs = in_channels
        out_chs = channels
        for downsample_scale in downsample_scales:
            self.convs += [
                torch.nn.Sequential(
                    # Convolve only along the folded time axis (kernel width 1),
                    # so samples `period` apart stay in separate columns.
                    torch.nn.Conv2d(
                        in_chs,
                        out_chs,
                        (kernel_sizes[0], 1),
                        (downsample_scale, 1),
                        padding=((kernel_sizes[0] - 1) // 2, 0),
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            out_chs = min(out_chs * channel_increasing_factor, max_downsample_channels)
        # NOTE(review): the final conv uses an even kernel (kernel_sizes[1] - 1)
        # with odd-kernel padding, slightly growing the time axis. This matches
        # the reference implementation; kept as-is for checkpoint compatibility
        # -- confirm before changing to kernel_sizes[1].
        self.output_conv = torch.nn.Conv2d(
            in_chs,
            out_channels,
            (kernel_sizes[1] - 1, 1),
            1,
            padding=((kernel_sizes[1] - 1) // 2, 0),
        )

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, in_channels, T).

        Returns:
            list: Outputs of each conv layer, followed by the flattened
                final output of shape (B, -1).
        """
        b, c, t = x.shape
        if t % self.period != 0:
            # Reflect-pad the tail so T becomes divisible by the period.
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        # (B, C, T) -> (B, C, T // period, period)
        x = x.view(b, c, t // self.period, self.period)

        outs = []
        for layer in self.convs:
            x = layer(x)
            outs += [x]
        x = self.output_conv(x)
        x = torch.flatten(x, 1, -1)
        outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv2d submodule."""

        def _apply_weight_norm(m):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)
|
|
|
|
class HiFiGANMultiPeriodDiscriminator(torch.nn.Module):
    """HiFiGAN multi-period discriminator module.

    Wraps one HiFiGANPeriodDiscriminator per period and runs the same
    input signal through each of them.
    """

    def __init__(
        self,
        periods=None,
        **kwargs,
    ):
        """Initialize HiFiGANMultiPeriodDiscriminator module.

        Args:
            periods (list): List of periods. Defaults to [2, 3, 5, 7, 11].
            **kwargs: Parameters forwarded to each HiFiGANPeriodDiscriminator;
                any ``period`` key is overwritten per discriminator.
        """
        super().__init__()
        # Use a None sentinel to avoid a shared mutable default argument.
        if periods is None:
            periods = [2, 3, 5, 7, 11]
        self.discriminators = torch.nn.ModuleList()
        for period in periods:
            # Deep-copy so nested dicts (e.g. nonlinear_activation_params)
            # are not shared between the per-period discriminators.
            params = copy.deepcopy(kwargs)
            params["period"] = period
            self.discriminators += [HiFiGANPeriodDiscriminator(**params)]

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            list: One entry per period discriminator, each a list of that
                discriminator's per-layer output tensors.
        """
        return [f(x) for f in self.discriminators]
|
|