import sys

sys.path.append("src")
import torch
import logging
import torch.nn as nn
from qa_mdt.audioldm_train.modules.clap.open_clip import create_model
from qa_mdt.audioldm_train.modules.clap.training.data import get_audio_features
import torchaudio
from transformers import (
    RobertaTokenizer,
    AutoTokenizer,
    T5EncoderModel,
    MT5EncoderModel,
)
import torch.nn.functional as F
from qa_mdt.audioldm_train.modules.audiomae.AudioMAE import Vanilla_AudioMAE
from qa_mdt.audioldm_train.modules.phoneme_encoder.encoder import TextEncoder
from transformers import SpeechT5Processor, AutoTokenizer, GPT2Model, GPT2Tokenizer
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithTextPrenet
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import CLAP2AudioMAE
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.sequence_input import (
    Sequence2AudioMAE,
)
import numpy as np
from qa_mdt.audioldm_train.modules.audiomae.sequence_gen.model import Prenet
import json

with open('./qa_mdt/offset_pretrained_checkpoints.json', 'r') as config_file:
    config_data = json.load(config_file)

"""
The forward function of a conditioning model can return three types of data:
1. tensor: used directly as the conditioning signal.
2. dict: one main key holds the condition; other keys can carry auxiliary losses, intermediate results, etc.
3. list of length 2: the first element is the condition tensor, the second element is the attention mask.
The output shape for a cross-attention condition should be:
    x, x_mask = [bs, seq_len, emb_dim], [bs, seq_len]
All returned data that will be used as diffusion input must be of float type.
"""

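# --- Illustrative sketch (not part of the original pipeline) ----------------
# A minimal conditioning module following the interface described above,
# returning [condition, attention_mask] (case 3). The class name, vocabulary
# size, and dimensions below are made-up example values.
class ExampleTokenEmbeddingCond(nn.Module):
    def __init__(self, vocab_size=100, emb_dim=192, seq_len=8):
        super().__init__()
        self.seq_len = seq_len
        self.emb = nn.Embedding(vocab_size, emb_dim)

    # Required by the conditioning interface: the unconditional (CFG) branch.
    def get_unconditional_condition(self, batchsize):
        device = self.emb.weight.device
        dummy_idx = torch.zeros((batchsize, self.seq_len), dtype=torch.long, device=device)
        return self(dummy_idx)

    def forward(self, token_idx):
        # token_idx: [bs, seq_len] -> condition [bs, seq_len, emb_dim], mask [bs, seq_len]
        cond = self.emb(token_idx).float()
        mask = torch.ones(token_idx.shape, device=token_idx.device).float()
        return [cond, mask]
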
class GPT2WordEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        # self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = GPT2Model.from_pretrained("gpt2").wte
        self.device = None

    def get_unconditional_condition(self, batchsize):
        unconditional_condition = ["random"] * batchsize
        return self(unconditional_condition)

    def forward(self, text):
        assert isinstance(text, list)
        if self.device is None:
            self.device = next(self.model.parameters()).device

        tokenization_result = self.tokenizer(text, return_tensors="pt", padding=True)
        input_ids, attn_mask = tokenization_result["input_ids"].to(
            self.device
        ), tokenization_result["attention_mask"].to(self.device)

        input_embed = self.model(input_ids.long())

        return [input_embed, attn_mask]

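# Usage sketch (for illustration; the GPT-2 base embedding dimension is 768):
#   embedder = GPT2WordEmbedding()
#   cond, mask = embedder(["a dog barking", "rain on a window"])
#   # cond: [bs, seq_len, 768], mask: [bs, seq_len]
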
class ConcateBandWidthCond(nn.Module):
    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        self.placeholder = nn.Linear(1, 1)
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        self.device = None

    def get_unconditional_condition(self, batchsize):
        return torch.zeros((batchsize, self.latent_t_size, self.latent_f_size)).to(
            self.device
        )

    def forward(self, mel_spec_bandwidth_cond_extra_channel):
        if self.device is None:
            self.device = mel_spec_bandwidth_cond_extra_channel.device

        return mel_spec_bandwidth_cond_extra_channel

class BandwidthEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 128)
        nn.init.normal_(self.emb.weight, 0.0, 128**-0.5)
        self.linear_bandwidth = nn.Linear(128, 128)
        self.unconditional_condition = torch.zeros((1, 256))
        self.device = None

    def get_unconditional_condition(self, batchsize):
        return self.unconditional_condition.expand(batchsize, 256)

    def forward(self, bandwidth):
        if self.device is None:
            self.device = next(self.linear_bandwidth.parameters()).device
            self.unconditional_condition = self.unconditional_condition.to(self.device)

        # freq_energy_percentile
        lower_cutoff, higher_cutoff = bandwidth[..., 0], bandwidth[..., 1]
        # lower_cutoff, higher_cutoff = lower_cutoff*0+5, higher_cutoff*0+300
        lower_cutoff_emb = self.linear_bandwidth(self.emb(lower_cutoff.long()))
        higher_cutoff_emb = self.linear_bandwidth(self.emb(higher_cutoff.long()))
        cutoff_emb = torch.cat([lower_cutoff_emb, higher_cutoff_emb], dim=-1)
        # [bs, 256]
        return cutoff_emb

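# Usage sketch (assumed input layout): `bandwidth` carries integer cutoff bins,
# e.g. shape [bs, 2] with (lower_cutoff, higher_cutoff), and is mapped to a
# [bs, 256] FiLM-style conditioning vector.
#   encoder = BandwidthEncoder()
#   cond = encoder(torch.tensor([[5, 300], [10, 500]]))  # -> [2, 256]
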
class SpeechT5TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.model = SpeechT5EncoderWithTextPrenet.from_pretrained(
            "microsoft/speecht5_tts"
        )
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()

    # Required
    def get_unconditional_condition(self, batchsize):
        device = self.model.device
        hidden_state = torch.zeros((batchsize, 1, 768)).to(device)
        attention_mask = torch.ones((batchsize, 1)).to(device)
        return [hidden_state.float(), attention_mask.float()]

    def forward(self, text):
        with torch.no_grad():
            device = self.model.device
            inputs = self.processor(text=text, return_tensors="pt", padding=True)
            input_ids, attention_mask = inputs["input_ids"].to(device), inputs[
                "attention_mask"
            ].to(device)
            emb = self.model(input_ids, attention_mask)
            emb = emb.last_hidden_state.detach()
        return [emb.float(), attention_mask.float()]

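# Usage sketch: the SpeechT5 text encoder is frozen and kept in eval mode; it
# returns [hidden_states, mask] with hidden size 768 (SpeechT5 base).
#   encoder = SpeechT5TextEncoder()
#   cond, mask = encoder(["a cat meowing"])
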
class PhonemeEncoder(nn.Module):
    def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
        super().__init__()
        """
        encoder = PhonemeEncoder(40)
        data = torch.randint(0, 39, (2, 250))
        output = encoder(data)
        import ipdb;ipdb.set_trace()
        """
        assert pad_token_id is not None

        self.device = None
        self.PAD_LENGTH = int(pad_length)
        self.pad_token_id = pad_token_id
        self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)

        self.text_encoder = TextEncoder(
            n_vocab=vocabs_size,
            out_channels=192,
            hidden_channels=192,
            filter_channels=768,
            n_heads=2,
            n_layers=6,
            kernel_size=3,
            p_dropout=0.1,
        )

        self.learnable_positional_embedding = torch.nn.Parameter(
            torch.zeros((1, 192, self.PAD_LENGTH))
        )  # [1, emb_dim, pad_length]
        self.learnable_positional_embedding.requires_grad = True

    # Required
    def get_unconditional_condition(self, batchsize):
        unconditional_tokens = self.pad_token_sequence.expand(
            batchsize, self.PAD_LENGTH
        )
        return self(unconditional_tokens)  # Need to return float type

    # def get_unconditional_condition(self, batchsize):
    #     hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device)
    #     attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device)
    #     return [hidden_state, attention_mask]  # Need to return float type

    def _get_src_mask(self, phoneme):
        src_mask = phoneme != self.pad_token_id
        return src_mask

    def _get_src_length(self, phoneme):
        src_mask = self._get_src_mask(phoneme)
        length = torch.sum(src_mask, dim=-1)
        return length

    # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask):
    #     # src_length: [bs]
    #     # text_emb: [bs, 192, pad_length]
    #     # attention_mask: [bs, pad_length]
    #     mask = src_length[..., None, None] > 1
    #     text_emb = text_emb * mask
    #     attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0
    #     return text_emb, attention_mask

    def forward(self, phoneme_idx):
        if self.device is None:
            self.device = self.learnable_positional_embedding.device
            self.pad_token_sequence = self.pad_token_sequence.to(self.device)

        src_length = self._get_src_length(phoneme_idx)
        text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
        text_emb = text_emb + self.learnable_positional_embedding

        # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask)

        return [
            text_emb.permute(0, 2, 1),
            text_emb_mask.squeeze(1),
        ]  # [2, 250, 192], [2, 250]

class FlanT5HiddenState(nn.Module):
    """
    llama = FlanT5HiddenState()
    data = ["","this is not an empty sentence"]
    encoder_hidden_states = llama(data)
    import ipdb;ipdb.set_trace()
    """

    def __init__(
        self, text_encoder_name=config_data['flan_t5'], freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder
        ## MODIFIED
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        self.empty_hidden_state_cfg = None
        self.device = None

    # Required
    def get_unconditional_condition(self, batchsize):
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        # device = param.device
        if self.empty_hidden_state_cfg is None:
            self.empty_hidden_state_cfg, _ = self([""])

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]  # Need to return float type

    def forward(self, batch):
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.device is None:
            self.device = param.device

        # print("Manually change text")
        # for i in range(len(batch)):
        #     batch[i] = "dog barking"

        try:
            return self.encode_text(batch)
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))

    def encode_text(self, prompt):
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,  # self.tokenizer.model_max_length
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )
        # Get text encoding
        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
        else:
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        return [
            encoder_hidden_states.detach(),
            attention_mask.float(),
        ]  # Attention mask == 1 means usable token

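# Usage sketch: FLAN-T5 large has hidden size 1024, so the condition is
# [bs, seq_len, 1024] with a [bs, seq_len] mask. Note that the
# text_encoder_name argument (config_data['flan_t5']) is currently overridden
# by the hard-coded "google/flan-t5-large" checkpoint above.
#   t5 = FlanT5HiddenState()
#   cond, mask = t5(["a dog barking", ""])
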
class FlanT5HiddenStatePaddedSameLength(nn.Module):
    """
    llama = FlanT5HiddenState()
    data = ["","this is not an empty sentence"]
    encoder_hidden_states = llama(data)
    import ipdb;ipdb.set_trace()
    """

    def __init__(
        self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        self.empty_hidden_state_cfg = None
        self.device = None

    # Required
    def get_unconditional_condition(self, batchsize):
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        # device = param.device
        if self.empty_hidden_state_cfg is None:
            self.empty_hidden_state_cfg, _ = self([""])

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]  # Need to return float type

    def forward(self, batch):
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert param.requires_grad == False

        if self.device is None:
            self.device = param.device

        # print("Manually change text")
        # for i in range(len(batch)):
        #     batch[i] = "dog barking"

        try:
            text_embed = self.encode_text(batch)
            return text_embed
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))

    def encode_text(self, prompt):
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )
        # Get text encoding
        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
        else:
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        return [
            encoder_hidden_states.detach(),
            attention_mask.float(),
        ]  # Attention mask == 1 means usable token

class CLAPGenAudioMAECond(CLAP2AudioMAE):
    def __init__(
        self,
        cond_stage_config,
        learnable=True,
        pretrained_path=None,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: use AudioMAE GT
        use_gt_mae_prob=None,  # The probability of using AudioMAE GT
    ):
        super().__init__(base_learning_rate=1e-5, cond_stage_config=cond_stage_config)
        assert use_gt_mae_output is not None and use_gt_mae_prob is not None

        if pretrained_path is not None:
            print("Reload CLAPGenAudioMAECond from %s" % pretrained_path)
            state_dict = torch.load(pretrained_path)["state_dict"]
            self.load_state_dict(state_dict)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze this module; only the GPT2 model is optimized elsewhere
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    # Required
    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return return_dict

    def forward(self, batch):
        # The conditioning module can return either a tensor or a dictionary
        # A returned tensor corresponds directly to the cond_stage_key
        # A returned dict has keys that correspond to the cond_stage_key
        ret_dict = {}

        if self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob:
            cond_dict = self.get_input(batch)
            # Used as condition
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
        else:
            # Used as condition
            input_embeds, cond_dict = self.generate(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # If the following two keys are not in cond_stage_key, they will not be used as condition
        ret_dict["film_clap_cond1"] = cond_dict[
            "film_clap_cond1"
        ]  # the clap target latent
        ret_dict["crossattn_audiomae_pooled"] = cond_dict[
            "crossattn_audiomae_pooled"
        ]  # audiomae target latent

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict

class SequenceGenAudioMAECond(Sequence2AudioMAE):
    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: use AudioMAE GT
        use_gt_mae_prob=None,  # The probability of using AudioMAE GT
    ):
        if use_warmup:
            print(
                "Warning: the sequence prediction module was not initialized with a trainer, so use_warmup is set to False. You can still use the warmup scheme of the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        if self.force_reload_pretrain_avoid_overwrite:
            self.is_reload = False
        else:
            self.is_reload = True

        self.load_pretrain_model()

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze this module; only the GPT2 model is optimized elsewhere
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        if self.pretrained_path is not None:
            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
            state_dict = torch.load(self.pretrained_path)["state_dict"]
            self.load_state_dict(state_dict)

    # Required
    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        # The conditioning module can return either a tensor or a dictionary
        # A returned tensor corresponds directly to the cond_stage_key
        # A returned dict has keys that correspond to the cond_stage_key
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            ret_dict["crossattn_audiomae_generated"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
            # _, output = self.training_step(batch, cond_dict=cond_dict, return_output=True)
            # ret_dict["crossattn_audiomae_generated"] = [output, torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float()]  # Input sequence and mask
        else:
            if not self.training:
                print("--------------> Generate !!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            # print("Generate Partial!!!!"); input_embeds, cond_dict = self.generate_partial(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # Keys that are not in cond_stage_key will not be used as condition
        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict

class SequenceGenAudioMAECond_AudioMAE_PostNet(Sequence2AudioMAE):
    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        use_ar_gen_loss=False,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: use AudioMAE GT
        use_gt_mae_prob=None,  # The probability of using AudioMAE GT
    ):
        if use_warmup:
            print(
                "Warning: the sequence prediction module was not initialized with a trainer, so use_warmup is set to False. You can still use the warmup scheme of the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_ar_gen_loss=use_ar_gen_loss,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        if self.force_reload_pretrain_avoid_overwrite:
            self.is_reload = False
        else:
            self.is_reload = True

        self.load_pretrain_model()

        self.prenet = Prenet(in_dim=768, sizes=[768, 768, 768], dropout_rate=0.5)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze this module; only the GPT2 model is optimized elsewhere
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        if self.pretrained_path is not None:
            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
            state_dict = torch.load(self.pretrained_path)["state_dict"]
            self.load_state_dict(state_dict)

    # Required
    def get_unconditional_condition(self, batchsize):
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        # The conditioning module can return either a tensor or a dictionary
        # A returned tensor corresponds directly to the cond_stage_key
        # A returned dict has keys that correspond to the cond_stage_key
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            gt_audiomae = self.prenet(cond_dict["crossattn_audiomae_pooled"][0])
            ret_dict["crossattn_audiomae_generated"] = [
                gt_audiomae,
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
        else:
            print("--------------> Generate!!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            # input_embeds, cond_dict = self.generate_partial(batch)
            input_embeds = self.prenet(input_embeds)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # Keys that are not in cond_stage_key will not be used as condition
        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict

class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
    """
    audiomae = AudioMAEConditionCTPool2x2()
    data = torch.randn((4, 1024, 128))
    output = audiomae(data)
    import ipdb;ipdb.set_trace()
    exit(0)
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        self.time_pooling_factors = time_pooling_factors
        self.freq_pooling_factors = freq_pooling_factors
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    # Required
    def get_unconditional_condition(self, batchsize):
        param = next(self.audiomae.parameters())
        assert param.requires_grad == False
        device = param.device
        # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
            self.eval_freq_pooling, 8
        )
        # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
        # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
        token_num = int(512 / (time_pool * freq_pool))
        return [
            torch.zeros((batchsize, token_num, 768)).to(device).float(),
            torch.ones((batchsize, token_num)).to(device).float(),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        assert representation.size(-1) == 768
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, token_num = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None and freq_pool is None:
                time_pool = min(
                    64,
                    self.time_pooling_factors[
                        np.random.choice(list(range(len(self.time_pooling_factors))))
                    ],
                )
                freq_pool = min(
                    8,
                    self.freq_pooling_factors[
                        np.random.choice(list(range(len(self.freq_pooling_factors))))
                    ],
                )
                # freq_pool = min(8, time_pool)  # TODO here I make some modification.
        else:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )

        self.avgpooling = nn.AvgPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )
        self.maxpooling = nn.MaxPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )

        pooled = (
            self.avgpooling(representation) + self.maxpooling(representation)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        assert x.size(-1) == 768
        x = F.normalize(x, p=2, dim=-1)
        return x

    # Required
    def forward(self, batch, time_pool=None, freq_pool=None):
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)

        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
            return [
                representation,
                torch.ones((representation.size(0), representation.size(1)))
                .to(representation.device)
                .float(),
            ]

class AudioMAEConditionCTPoolRand(nn.Module):
    """
    audiomae = AudioMAEConditionCTPool2x2()
    data = torch.randn((4, 1024, 128))
    output = audiomae(data)
    import ipdb;ipdb.set_trace()
    exit(0)
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        self.time_pooling_factors = time_pooling_factors
        self.freq_pooling_factors = freq_pooling_factors
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    # Required
    def get_unconditional_condition(self, batchsize):
        param = next(self.audiomae.parameters())
        assert param.requires_grad == False
        device = param.device
        # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
            self.eval_freq_pooling, 8
        )
        # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
        # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
        token_num = int(512 / (time_pool * freq_pool))
        return [
            torch.zeros((batchsize, token_num, 768)).to(device).float(),
            torch.ones((batchsize, token_num)).to(device).float(),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        assert representation.size(-1) == 768
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, token_num = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None and freq_pool is None:
                time_pool = min(
                    64,
                    self.time_pooling_factors[
                        np.random.choice(list(range(len(self.time_pooling_factors))))
                    ],
                )
                # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
                freq_pool = min(8, time_pool)  # TODO here I make some modification.
        else:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )

        self.avgpooling = nn.AvgPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )
        self.maxpooling = nn.MaxPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )

        pooled = (
            self.avgpooling(representation) + self.maxpooling(representation)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        assert x.size(-1) == 768
        x = F.normalize(x, p=2, dim=-1)
        return x

    # Required
    def forward(self, batch, time_pool=None, freq_pool=None):
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)

        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
            return [
                representation,
                torch.ones((representation.size(0), representation.size(1)))
                .to(representation.device)
                .float(),
            ]

class ConditionalToken(nn.Module):
    def __init__(self, embedding_dim):
        super(ConditionalToken, self).__init__()
        self.embedding_dim = embedding_dim
        # Define the conditional tokens as fixed values
        self.pooling_factor_tokens = {
            1: torch.Tensor([1.0, 0.0] * (embedding_dim // 2)),
            2: torch.Tensor([0.0, 1.0] * (embedding_dim // 2)),
            4: torch.Tensor([1.0, 1.0] * (embedding_dim // 2)),
            8: torch.Tensor([-1.0, 0.0] * (embedding_dim // 2)),
            16: torch.Tensor([0.0, -1.0] * (embedding_dim // 2)),
            32: torch.Tensor([-1.0, -1.0] * (embedding_dim // 2)),
            64: torch.Tensor([0.0, 0.0] * (embedding_dim // 2)),
        }
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, condition, batchsize):
        """
        Returns the conditional token for the given condition.
        """
        if condition not in self.pooling_factor_tokens.keys():
            raise ValueError(f"Unsupported condition: {condition}")
        batched_token = self.pooling_factor_tokens[condition][None, None].expand(
            batchsize, 1, self.embedding_dim
        )
        return batched_token

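# Usage sketch: maps a time-pooling factor to a fixed token that is appended to
# the pooled AudioMAE sequence (see AudioMAEConditionCTPoolRandV2 below).
#   token_bank = ConditionalToken(768)
#   tok = token_bank(4, batchsize=2)  # -> [2, 1, 768]
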
class AudioMAEConditionCTPoolRandV2(nn.Module):
    """
    audiomae = AudioMAEConditionCTPool2x2()
    data = torch.randn((4, 1024, 128))
    output = audiomae(data)
    import ipdb;ipdb.set_trace()
    exit(0)
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        self.time_pooling_factors = time_pooling_factors
        self.freq_pooling_factors = freq_pooling_factors
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        self.pooling_tokens = ConditionalToken(768)

        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    # Required
    def get_unconditional_condition(self, batchsize):
        param = next(self.audiomae.parameters())
        assert param.requires_grad == False
        device = param.device
        # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
        time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
            self.eval_freq_pooling, 8
        )
        # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
        # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
        pool_condition_token = self.pooling_tokens(time_pool, batchsize).to(device)
        token_num = int(512 / (time_pool * freq_pool))

        rep = torch.zeros((batchsize, token_num, 768)).to(device).float()
        rep = torch.cat([rep, pool_condition_token], dim=1)

        return [rep, torch.ones((batchsize, token_num + 1)).to(device).float()]

    def pool(self, representation, time_pool=None, freq_pool=None):
        assert representation.size(-1) == 768
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, token_num = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None and freq_pool is None:
                time_pool = min(
                    64,
                    self.time_pooling_factors[
                        np.random.choice(list(range(len(self.time_pooling_factors))))
                    ],
                )
                # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
                freq_pool = min(8, time_pool)  # TODO here I make some modification.
        else:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )

        self.avgpooling = nn.AvgPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )
        self.maxpooling = nn.MaxPool2d(
            kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
        )

        pooled = (
            self.avgpooling(representation) + self.maxpooling(representation)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled, time_pool, freq_pool  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        assert x.size(-1) == 768
        x = F.normalize(x, p=2, dim=-1)
        return x

    # Required
    def forward(self, batch):
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)

        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation, time_pool, freq_pool = self.pool(representation)
            if self.use_reg:
                representation = self.regularization(representation)
            pool_condition_token = self.pooling_tokens(
                time_pool, representation.size(0)
            ).to(representation.device)
            representation = torch.cat([representation, pool_condition_token], dim=1)
            return [
                representation,
                torch.ones((representation.size(0), representation.size(1)))
                .to(representation.device)
                .float(),
            ]

class BeatDownbeatConditionConcat(nn.Module):
    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        self.device = None

    # Required
    def get_unconditional_condition(self, batchsize):
        return torch.zeros((batchsize, self.latent_t_size, self.latent_f_size)).to(
            self.device
        )

    # Required
    def forward(self, batch):
        if self.device is None:
            self.device = batch.device

        return batch

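# Usage sketch: this conditioner returns a plain tensor (case 1 in the module
# docstring); as the class name suggests, it is meant to be concatenated with
# the diffusion latent. Sizes below are example values.
#   cond = BeatDownbeatConditionConcat(latent_t_size=256, latent_f_size=16)
#   out = cond(torch.zeros(2, 256, 16))  # passthrough; uncond is all zeros
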
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    def __init__(
        self,
        pretrained_path,
        sampling_rate=16000,
        embed_mode="audio",
        amodel="HTSAT-base",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
    ):
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.tokenize = RobertaTokenizer.from_pretrained(config_data["roberta-base"])
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )
        audio_cfg = self.model_cfg["audio_cfg"]
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=audio_cfg["sample_rate"],
            n_fft=audio_cfg["window_size"],
            win_length=audio_cfg["window_size"],
            hop_length=audio_cfg["hop_size"],
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=audio_cfg["fmin"],
            f_max=audio_cfg["fmax"],
        )
        for p in self.model.parameters():
            p.requires_grad = False
        self.unconditional_token = None
        self.model.eval()

    def get_unconditional_condition(self, batchsize):
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        ret = []
        for i in range(batch.size(0)):
            ret.append(batch[i])
        return ret

    def make_decision(self, probability):
        if float(torch.rand(1)) < probability:
            return True
        else:
            return False

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        # waveform: [bs, t-steps]
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        # waveform: [bs, t_steps]
        original_embed_mode = self.embed_mode
        with torch.no_grad():
            self.embed_mode = "audio"
            audio_emb = self(waveform.cuda())
            self.embed_mode = "text"
            text_emb = self(text)
            similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
        self.embed_mode = original_embed_mode
        return similarity.squeeze()

    def build_unconditional_emb(self):
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]

    def forward(self, batch):
        # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
        # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
        if self.model.training == True and not self.training_mode:
            print(
                "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
            )
            self.model, self.model_cfg = create_model(
                self.amodel,
                self.tmodel,
                self.pretrained,
                precision=self.precision,
                device="cuda",
                enable_fusion=self.enable_fusion,
                fusion_type=self.fusion_type,
            )
            for p in self.model.parameters():
                p.requires_grad = False
            self.model.eval()

        if self.unconditional_token is None:
            self.build_unconditional_emb()

        # if self.training_mode:
        #     assert self.model.training == True
        # else:
        #     assert self.model.training == False

        # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
        if self.embed_mode == "audio":
            if not self.training:
                print("INFO: clap model calculate the audio embedding as condition")
            with torch.no_grad():
                # assert (
                #     self.sampling_rate == 16000
                # ), "We only support 16000 sampling rate"
                # if self.random_mute:
                #     batch = self._random_mute(batch)
                # batch: [bs, 1, t-samples]
                if self.sampling_rate != 48000:
                    batch = torchaudio.functional.resample(
                        batch, orig_freq=self.sampling_rate, new_freq=48000
                    )
                audio_data = batch.squeeze(1)
                mel = self.mel_transform(audio_data)
                audio_dict = get_audio_features(
                    audio_data,
                    mel,
                    480000,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                )
                # [bs, 512]
                embed = self.model.get_audio_embedding(audio_dict)
        elif self.embed_mode == "text":
            with torch.no_grad():
                # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
                text_data = self.tokenizer(batch)

                if isinstance(batch, str) or (
                    isinstance(batch, list) and len(batch) == 1
                ):
                    for key in text_data.keys():
                        text_data[key] = text_data[key].unsqueeze(0)

                embed = self.model.get_text_embedding(text_data)

        embed = embed.unsqueeze(1)
        for i in range(embed.size(0)):
            if self.make_decision(self.unconditional_prob):
                embed[i] = self.unconditional_token
        # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch)
        return embed.detach()

    def tokenizer(self, text):
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}

if __name__ == "__main__":
    model = CLAPAudioEmbeddingClassifierFreev2(
        pretrained_path="/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/ckpt/CLAP.pt",
        embed_mode="text",
        amodel="HTSAT-tiny",
    )
    # data = torch.randn((6, 1, int(16000*10.24)))
    data = ["text", "text"]
    res = model(data)
    import ipdb

    ipdb.set_trace()