Spaces:
Build error
Build error
| import os | |
| import sys | |
| sys.path.insert(1, os.path.join(sys.path[0], '../utils')) | |
| import numpy as np | |
| import argparse | |
| import h5py | |
| import math | |
| import time | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import torch | |
| torch.backends.cudnn.benchmark=True | |
| torch.manual_seed(0) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| import torch.utils.data | |
| from utilities import get_filename | |
| from models import * | |
| import config | |
class Transfer_Cnn14(nn.Module):
    """Classifier for a new task using a pretrained Cnn14 as a sub-module.

    ``self.base`` is a ``Cnn14`` built with the AudioSet output size (527) so
    that AudioSet checkpoints can be loaded into it. A new linear layer,
    ``fc_transfer``, maps the base's 2048-dimensional ``'embedding'`` output
    to ``classes_num`` outputs for the new task.
    """

    def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
                 fmax, classes_num, freeze_base):
        """Build the pretrained base and the new task head.

        Args:
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax:
                forwarded unchanged to ``Cnn14``.
            classes_num: number of output classes of the new task.
            freeze_base: if true, disable gradients for every parameter of
                the base so only ``fc_transfer`` is trainable.
        """
        super(Transfer_Cnn14, self).__init__()
        # The base must match the AudioSet checkpoint's head size (527
        # classes) for load_state_dict to succeed.
        audioset_classes_num = 527

        self.base = Cnn14(sample_rate, window_size, hop_size, mel_bins, fmin,
                          fmax, audioset_classes_num)

        # Transfer to another task layer
        self.fc_transfer = nn.Linear(2048, classes_num, bias=True)

        if freeze_base:
            # Freeze AudioSet pretrained layers
            for param in self.base.parameters():
                param.requires_grad = False

        self.init_weights()

    def init_weights(self):
        """Initialise the new task head with the project's ``init_layer``."""
        init_layer(self.fc_transfer)

    def load_from_pretrain(self, pretrained_checkpoint_path, map_location='cpu'):
        """Load pretrained weights into ``self.base``.

        Args:
            pretrained_checkpoint_path: path to a file written by
                ``torch.save`` holding a dict whose ``'model'`` entry is the
                base's state dict.
            map_location: forwarded to ``torch.load``. It defaults to
                ``'cpu'`` so that checkpoints saved from GPU tensors also load
                on CPU-only machines. ``load_state_dict`` copies the values
                into the existing parameters, so the model's own device is
                unaffected.
        """
        checkpoint = torch.load(pretrained_checkpoint_path,
                                map_location=map_location)
        self.base.load_state_dict(checkpoint['model'])

    def forward(self, input, mixup_lambda=None):
        """Input: (batch_size, data_length)

        Returns the base's output dict with ``'clipwise_output'`` replaced by
        ``log_softmax`` of the new head over the last dimension (that is,
        log-probabilities for the new task).
        """
        output_dict = self.base(input, mixup_lambda)
        embedding = output_dict['embedding']

        clipwise_output = torch.log_softmax(self.fc_transfer(embedding), dim=-1)
        output_dict['clipwise_output'] = clipwise_output

        return output_dict
def train(args):
    """Build the model described by ``args`` and place it on a device.

    No training loop runs here. The function instantiates the model class
    named by ``args.model_type`` and optionally loads pretrained base weights
    from ``args.pretrained_checkpoint_path``. It then wraps the model in
    ``torch.nn.DataParallel`` and moves it to CUDA when ``args.cuda`` is set
    and CUDA is available.

    Args:
        args: parsed command-line namespace providing sample_rate,
            window_size, hop_size, mel_bins, fmin, fmax, model_type,
            pretrained_checkpoint_path, freeze_base and cuda.

    Returns:
        The ``DataParallel``-wrapped model. Existing callers that ignore the
        return value are unaffected.

    Raises:
        ValueError: if ``args.model_type`` does not name a class available in
            this module's namespace.
    """
    # Arguments & parameters
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    pretrained_checkpoint_path = args.pretrained_checkpoint_path
    freeze_base = args.freeze_base
    device = 'cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu'
    classes_num = config.classes_num
    pretrain = bool(pretrained_checkpoint_path)

    # Model. model_type comes straight from the command line, so resolve it
    # as a class name in this module's namespace instead of eval()-ing it,
    # which would execute arbitrary expressions.
    Model = globals().get(model_type)
    if not isinstance(Model, type):
        raise ValueError('Unknown model_type: {}'.format(model_type))
    model = Model(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                  classes_num, freeze_base)

    # Load pretrained model
    if pretrain:
        logging.info('Load pretrained model from {}'.format(pretrained_checkpoint_path))
        model.load_from_pretrain(pretrained_checkpoint_path)

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)

    if 'cuda' in device:
        model.to(device)

    # Only report success when a checkpoint was actually loaded.
    if pretrain:
        print('Load pretrained model successfully!')

    return model
if __name__ == '__main__':
    # Command-line entry point: a single `train` sub-command whose options
    # mirror the attributes read by train().
    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')

    # Train
    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--sample_rate', type=int, required=True)
    parser_train.add_argument('--window_size', type=int, required=True)
    parser_train.add_argument('--hop_size', type=int, required=True)
    parser_train.add_argument('--mel_bins', type=int, required=True)
    parser_train.add_argument('--fmin', type=int, required=True)
    parser_train.add_argument('--fmax', type=int, required=True)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--pretrained_checkpoint_path', type=str)
    parser_train.add_argument('--freeze_base', action='store_true', default=False)
    parser_train.add_argument('--cuda', action='store_true', default=False)

    # Parse arguments
    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)
    else:
        # With no sub-command, args.mode is None. Report this as an argparse
        # usage error (prints usage, exits with status 2) rather than a bare
        # Exception traceback.
        parser.error('Error argument! Expected a sub-command: train')