import sys
# sys.path.append("src")
import shutil
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import argparse
import yaml
import torch
from tqdm import tqdm
from pytorch_lightning.strategies.ddp import DDPStrategy
from qa_mdt.audioldm_train.modules.latent_diffusion.ddpm import LatentDiffusion
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from qa_mdt.audioldm_train.utilities.tools import (
    listdir_nohidden,
    get_restore_step,
    copy_test_subset_data,
)
import wandb
from qa_mdt.audioldm_train.utilities.model_util import instantiate_from_config
import logging

logging.basicConfig(level=logging.WARNING)
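
# The YAML config passed on the command line must provide at least the keys
# read in main() below. An illustrative skeleton (placeholder values, not
# tested defaults -- adjust to your setup):
#
#   seed: 1234
#   precision: high              # torch.set_float32_matmul_precision value
#   log_directory: ./log
#   project: qa_mdt
#   mos_path: ./mos
#   reload_from_ckpt: null       # optional warm-start checkpoint
#   train_path:
#     train_lmdb_path: [./lmdb/train]
#   val_path:
#     val_lmdb_path: [./lmdb/val]
#     val_key_path: [./lmdb/val/data_key.key]
#   step:
#     limit_val_batches: null
#     validation_every_n_epochs: 1
#     save_checkpoint_every_n_steps: 5000
#     max_steps: 800000
#     save_top_k: 3
#   model:
#     params:
#       batchsize: 8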


def convert_path(path):
    # Keep only the last four components of an LMDB byte-string key path and
    # rejoin them into a relative path. (The original dropped the result;
    # returning it makes the helper usable.)
    parts = path.decode().split("/")[-4:]
    return "/".join(parts)
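
# Example (illustrative path, not a real file):
#   convert_path(b"/data/audioset/part9/Y7fmOlUlwoNg.wav")
#   -> "data/audioset/part9/Y7fmOlUlwoNg.wav"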


def print_on_rank0(msg):
    # Print only on the main process so logs are not duplicated under DDP;
    # fall back to a plain print when torch.distributed is not initialized
    # (get_rank() raises without an initialized process group).
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        print(msg)


def main(configs, config_yaml_path, exp_group_name, exp_name, perform_validation):
    print("MAIN START")
    # cpth = "/train20/intern/permanent/changli7/dataset_ptm/test_dataset/dataset/audioset/zip_audios/unbalanced_train_segments/unbalanced_train_segments_part9/Y7fmOlUlwoNg.wav"
    # convert_path(cpth)
    if "seed" in configs.keys():
        seed_everything(configs["seed"])
    else:
        print("No seed specified in the config; seeding everything with 1234")
        seed_everything(1234)
| if "precision" in configs.keys(): | |
| torch.set_float32_matmul_precision( | |
| configs["precision"] | |
| ) # highest, high, medium | |
| log_path = configs["log_directory"] | |
| batch_size = configs["model"]["params"]["batchsize"] | |
| train_lmdb_path = configs["train_path"]["train_lmdb_path"] | |
| train_key_path = [_ + '/data_key.key' for _ in train_lmdb_path] | |
| val_lmdb_path = configs["val_path"]["val_lmdb_path"] | |
| val_key_path = configs["val_path"]["val_key_path"] | |

    mos_path = configs["mos_path"]

    # Local import: the dataset class lives alongside the training utilities.
    from qa_mdt.audioldm_train.utilities.data.hhhh import AudioDataset

    dataset = AudioDataset(
        config=configs,
        lmdb_path=train_lmdb_path,
        key_path=train_key_path,
        mos_path=mos_path,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=8,
        pin_memory=True,
        shuffle=True,
    )
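
    # WeightedRandomSampler is imported above but never used. A minimal sketch
    # of MOS-weighted sampling, assuming a hypothetical per-item `weights`
    # sequence on the dataset (not an attribute AudioDataset is known to have):
    #
    #   sampler = WeightedRandomSampler(
    #       dataset.weights, num_samples=len(dataset), replacement=True
    #   )
    #   loader = DataLoader(dataset, batch_size=batch_size, num_workers=8,
    #                       pin_memory=True, sampler=sampler)
    #
    # Note: `sampler` and `shuffle=True` are mutually exclusive in DataLoader.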

    print(
        "The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s"
        % (len(dataset), len(loader), batch_size)
    )

    try:
        val_dataset = AudioDataset(
            config=configs,
            lmdb_path=val_lmdb_path,
            key_path=val_key_path,
            mos_path=mos_path,
        )
    except Exception:
        # Fall back to a validation dataset built without MOS scores.
        val_dataset = AudioDataset(
            config=configs, lmdb_path=val_lmdb_path, key_path=val_key_path
        )
    val_loader = DataLoader(
        val_dataset,
        batch_size=8,
    )
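
    # Validation runs with a fixed batch size of 8 and DataLoader's default
    # shuffle=False, so validation batches are deterministic across runs.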

    # Copy test data
    test_data_subset_folder = os.path.join(
        os.path.dirname(configs["log_directory"]),
        "testset_data",
        "tmp",
    )
    os.makedirs(test_data_subset_folder, exist_ok=True)
    # Copy the validation files into the test subset folder:
    # for i in range(len(val_dataset.keys)):
    #     key_tmp = val_dataset.keys[i].decode()
    #     cmd = "cp {} {}".format(key_tmp, os.path.join(test_data_subset_folder))
    #     os.system(cmd)

    # dict.get() returns None when the key is absent, replacing the original
    # bare try/except blocks.
    config_reload_from_ckpt = configs.get("reload_from_ckpt")
    limit_val_batches = configs["step"].get("limit_val_batches")
    validation_every_n_epochs = configs["step"]["validation_every_n_epochs"]
    save_checkpoint_every_n_steps = configs["step"]["save_checkpoint_every_n_steps"]
    max_steps = configs["step"]["max_steps"]
    save_top_k = configs["step"]["save_top_k"]

    checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints")
    wandb_path = os.path.join(log_path, exp_group_name, exp_name)

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_path,
        monitor="global_step",
        mode="max",
        filename="checkpoint-fad-{val/frechet_inception_distance:.2f}-global_step={global_step:.0f}",
        every_n_train_steps=save_checkpoint_every_n_steps,
        save_top_k=save_top_k,
        auto_insert_metric_name=False,
        save_last=False,
    )
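
    # Monitoring "global_step" with mode="max" means save_top_k keeps the k
    # most recent checkpoints; the filename template embeds the
    # val/frechet_inception_distance metric logged during validation.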
    os.makedirs(checkpoint_path, exist_ok=True)
    # shutil.copy(config_yaml_path, wandb_path)

    # Resume priority: an existing checkpoint in checkpoint_path wins, then a
    # checkpoint specified in the config, otherwise train from scratch.
    if len(os.listdir(checkpoint_path)) > 0:
        print("Load checkpoint from path: %s" % checkpoint_path)
        restore_step, n_step = get_restore_step(checkpoint_path)
        resume_from_checkpoint = os.path.join(checkpoint_path, restore_step)
        print("Resume from checkpoint", resume_from_checkpoint)
    elif config_reload_from_ckpt is not None:
        resume_from_checkpoint = config_reload_from_ckpt
        print("Reload ckpt specified in the config file %s" % resume_from_checkpoint)
    else:
        print("Train from scratch")
        resume_from_checkpoint = None

    devices = torch.cuda.device_count()  # informational; the Trainer below uses devices="auto"
    latent_diffusion = instantiate_from_config(configs["model"])
    latent_diffusion.set_log_dir(log_path, exp_group_name, exp_name)

    wandb_logger = WandbLogger(
        save_dir=wandb_path,
        project=configs["project"],
        config=configs,
        name="%s/%s" % (exp_group_name, exp_name),
    )

    latent_diffusion.test_data_subset_path = test_data_subset_folder

    print("==> Save checkpoint every %s steps" % save_checkpoint_every_n_steps)
    print("==> Perform validation every %s epochs" % validation_every_n_epochs)

    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        logger=wandb_logger,
        max_steps=max_steps,
        num_sanity_val_steps=1,
        limit_val_batches=limit_val_batches,
        check_val_every_n_epoch=validation_every_n_epochs,
        strategy=DDPStrategy(find_unused_parameters=True),
        gradient_clip_val=2.0,
        callbacks=[checkpoint_callback],
        num_nodes=1,
    )
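
    # find_unused_parameters=True lets DDP tolerate submodules that produce no
    # gradient in a given step (at extra synchronization cost), and Lightning
    # applies gradient clipping at 2.0 via gradient_clip_val.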
    trainer.fit(latent_diffusion, loader, val_loader, ckpt_path=resume_from_checkpoint)

    ############################################################################
    # Alternative manual reload path: filter checkpoint keys that are missing
    # from, or size-mismatched with, the current model before loading.
    #
    # if resume_from_checkpoint is not None:
    #     ckpt = torch.load(resume_from_checkpoint)["state_dict"]
    #     key_not_in_model_state_dict = []
    #     size_mismatch_keys = []
    #     state_dict = latent_diffusion.state_dict()
    #     print("Filtering key for reloading:", resume_from_checkpoint)
    #     print("State dict key size:", len(list(state_dict.keys())), len(list(ckpt.keys())))
    #     for key in tqdm(list(ckpt.keys())):
    #         if key not in state_dict.keys():
    #             key_not_in_model_state_dict.append(key)
    #             del ckpt[key]
    #             continue
    #         if state_dict[key].size() != ckpt[key].size():
    #             del ckpt[key]
    #             size_mismatch_keys.append(key)
    #     if len(key_not_in_model_state_dict) != 0 or len(size_mismatch_keys) != 0:
    #         print("⛳", end=" ")
    #     print("==> Warning: The following key in the checkpoint is not presented in the model:", key_not_in_model_state_dict)
    #     print("==> Warning: These keys have different size between checkpoint and current model: ", size_mismatch_keys)
    #     latent_diffusion.load_state_dict(ckpt, strict=False)
    #     if perform_validation:
    #         trainer.validate(latent_diffusion, val_loader)
    #     trainer.fit(latent_diffusion, loader, val_loader)
    ############################################################################


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config_yaml",
        type=str,
        required=True,  # the config is used unconditionally below
        help="path to config .yaml file",
    )
    parser.add_argument("--val", action="store_true")
    args = parser.parse_args()

    perform_validation = args.val
    assert torch.cuda.is_available(), "CUDA is not available"

    config_yaml_path = args.config_yaml
    # splitext is robust to dots elsewhere in the path, unlike split(".")[0].
    exp_name = os.path.splitext(os.path.basename(config_yaml_path))[0]
    exp_group_name = os.path.basename(os.path.dirname(config_yaml_path))
    with open(config_yaml_path, "r") as f:
        config_yaml = yaml.load(f, Loader=yaml.FullLoader)

    if perform_validation:
        config_yaml["model"]["params"]["cond_stage_config"][
            "crossattn_audiomae_generated"
        ]["params"]["use_gt_mae_output"] = False
        config_yaml["step"]["limit_val_batches"] = None

    main(config_yaml, config_yaml_path, exp_group_name, exp_name, perform_validation)