| | from utils import *
|
| | from modules import *
|
| | import os, sys
|
| | import numpy as np
|
| | from tqdm import tqdm
|
| | import random
|
| | import torch
|
| | from torch import nn
|
| | from config import CFG
|
| | from dataset import *
|
| | import torch.utils.data
|
| | import copy, json, pickle
|
| | import itertools as it
|
| | import loss
|
| |
|
# Criterion shared by train_epoch/valid_epoch below; taken from the local
# `loss` module (presumably an InfoNCE-style contrastive loss — see loss.py).
loss_func = loss.infoNCE_loss2
|
def make_next_record_dir(basedir, prefix=''):
    """Create and return the next free numbered record directory.

    Directories are named ``{basedir}/{prefix}001/``, ``{prefix}002/``, ...
    The first index not yet present on disk is created and its path
    (with trailing slash) returned.

    Args:
        basedir: parent directory for the numbered run directories.
        prefix: optional name prefix before the 3-digit counter.

    Returns:
        The created directory path, ending with '/'.
    """
    n = 1
    while True:
        pth = '%s/%s%.3d/' % (basedir, prefix, n)
        try:
            # Atomically claim the directory: creating it (instead of the
            # original exists()-then-makedirs pattern) avoids the race where
            # two concurrent runs pick the same index.
            os.makedirs(pth)
            return pth
        except FileExistsError:
            n += 1
|
| |
|
def setup_seed(seed):
    """Seed every RNG in use (torch CPU/CUDA, numpy, random).

    Also forces deterministic cuDNN kernels so repeated runs with the
    same seed produce identical results on GPU.
    """
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   np.random.seed,
                   random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
|
| |
|
def my_collate(batch):
    """Collate a list of sample dicts into a single batch dict.

    None entries (samples the dataset skipped) are dropped. Tensor fields
    'ms_bins', 'mol_fps' and 'mol_fmvec' are stacked when present. Graph
    fields 'V'/'A' are zero-padded to the largest molecule in the batch
    (via pad_V/pad_A) before stacking, and 'mol_size' is concatenated —
    these three are only emitted when all of them are present.
    """
    samples = [s for s in batch if s is not None]
    collated = {}

    # Fixed-size tensor fields: stack whatever samples carry them.
    for key in ('ms_bins', 'mol_fps', 'mol_fmvec'):
        values = [s[key] for s in samples if key in s]
        if values:
            collated[key] = torch.stack(values)

    node_feats = [s['V'] for s in samples if 'V' in s]
    adjs = [s['A'] for s in samples if 'A' in s]
    sizes = [s['mol_size'] for s in samples if 'mol_size' in s]

    if node_feats and adjs and sizes:
        # Pad every graph to the node count of the largest one.
        max_nodes = max(v.shape[0] for v in node_feats)
        collated['V'] = torch.stack([pad_V(v, max_nodes) for v in node_feats])
        collated['A'] = torch.stack([pad_A(a, max_nodes) for a in adjs])
        collated['mol_size'] = torch.cat(sizes, dim=0)

    return collated
|
| |
|
def make_train_valid(data, valid_ratio, seed=1234):
    """Deterministically split `data` into (train, valid) lists.

    Reseeds numpy's global RNG with `seed`, shuffles the index range, and
    uses the first floor(valid_ratio * len(data)) indices for validation.

    Returns:
        (train_set, valid_set) as plain lists.
    """
    order = np.arange(len(data))
    np.random.seed(seed)
    np.random.shuffle(order)

    n_valid = int(valid_ratio * len(data))
    valid_idx = order[:n_valid]
    train_idx = order[n_valid:]

    return [data[i] for i in train_idx], [data[i] for i in valid_idx]
|
| |
|
def build_loaders(inp, mode, cfg, num_workers):
    """Wrap `inp` in the matching dataset and return a DataLoader.

    Samples that are dicts (already-loaded records) go to Dataset; anything
    else (assumed to be file paths — confirm against PathDataset) goes to
    PathDataset. Shuffling is enabled only for mode == "train".

    Args:
        inp: non-empty list of sample dicts or of paths.
        mode: "train" enables shuffling; any other value disables it.
        cfg: config providing batch_size.
        num_workers: DataLoader worker-process count.
    """
    # isinstance instead of `type(...) is dict`: idiomatic, and also accepts
    # dict subclasses (e.g. OrderedDict) transparently.
    if isinstance(inp[0], dict):
        dataset = Dataset(inp, cfg)
    else:
        dataset = PathDataset(inp, cfg)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=(mode == "train"),
        collate_fn=my_collate,
    )
|
| |
|
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch.

    Args:
        model: network returning (mol_features, ms_features) for a batch dict.
        train_loader: DataLoader yielding batch dicts (see my_collate).
        optimizer: stepped once per batch.
        lr_scheduler: stepped per batch only when step == "batch".
        step: "batch" or "epoch"; controls whether the scheduler is stepped here.

    Returns:
        (loss_meter, mean cosine similarity across batches).
    """
    model.train()
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    total_cos_sim = 0

    for batch in tqdm_object:
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)

        optimizer.zero_grad()

        mol_features, ms_features = model(batch)

        # Renamed from `loss` to avoid shadowing the imported `loss` module
        # that provides loss_func.
        batch_loss = loss_func(mol_features, ms_features)

        batch_loss.backward()
        optimizer.step()

        # Cosine similarity between the two embedding spaces is tracked
        # purely for monitoring; no gradients needed.
        with torch.no_grad():
            cos_sim = F.cosine_similarity(
                mol_features.detach(),
                ms_features.detach()
            ).mean().item()
        total_cos_sim += cos_sim

        if step == "batch":
            lr_scheduler.step()

        count = batch["ms_bins"].size(0)
        loss_meter.update(batch_loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg,
                                train_cos_sim=round(cos_sim, 4),
                                lr=get_lr(optimizer))

        # Drop per-batch tensor references eagerly to keep GPU memory low.
        del mol_features, ms_features, batch_loss, cos_sim
        for k in list(batch.keys()):
            del batch[k]
        del batch

    return loss_meter, total_cos_sim / len(train_loader)
|
| |
|
def valid_epoch(model, valid_loader):
    """Evaluate the model for one pass over valid_loader.

    Mirrors train_epoch but without gradients or optimizer updates.

    Returns:
        (loss_meter, mean cosine similarity across batches).
    """
    model.eval()
    loss_meter = AvgMeter()
    total_cos_sim = 0

    with torch.no_grad():
        tqdm_object = tqdm(valid_loader, total=len(valid_loader))
        for batch in tqdm_object:
            for k, v in batch.items():
                batch[k] = v.to(CFG.device)

            mol_features, ms_features = model(batch)

            # Renamed from `loss` to avoid shadowing the imported `loss`
            # module that provides loss_func (consistent with train_epoch).
            batch_loss = loss_func(mol_features, ms_features)

            count = batch["ms_bins"].size(0)
            loss_meter.update(batch_loss.item(), count)

            cos_sim = F.cosine_similarity(mol_features.detach(),
                                          ms_features.detach()).mean().item()
            total_cos_sim += cos_sim

            tqdm_object.set_postfix(valid_loss=loss_meter.avg,
                                    valid_cos_sim=round(cos_sim, 4))

            # Drop references so tensors can be freed between batches.
            del mol_features, ms_features, batch_loss, cos_sim
            for k in list(batch.keys()):
                del batch[k]
            del batch

    return loss_meter, total_cos_sim / len(valid_loader)
|
| |
|
def main(data, cfg=CFG, savedir='data/train', model_path=None, ratio=1):
    """Train a FragSimiModel on `data` and checkpoint every epoch.

    Args:
        data: full dataset (list of sample dicts or of paths; see build_loaders).
        cfg: config object (seed, valid_ratio, lr, epochs, device, ...).
        savedir: directory for checkpoints and the training log.
        model_path: optional checkpoint to resume model weights from.
        ratio: sub-sample this fraction of the training split when < 1.

    Returns:
        (list of checkpoint paths kept on disk, best validation loss).
    """
    setup_seed(cfg.seed)

    train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio,
                                            seed=cfg.seed)

    log_file = f'{savedir}/trainlog.txt'

    n = len(train_set)
    if ratio < 1:
        train_set = random.sample(train_set, int(n * ratio))
        print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')

    train_loader = build_loaders(train_set, "train", cfg, 1)
    valid_loader = build_loaders(valid_set, "valid", cfg, 1)

    # Scheduler is stepped per epoch (on validation loss), not per batch.
    step = "epoch"

    best_loss = float('inf')
    best_model_fns = []   # every checkpoint path ever written
    best_model_fnl = []   # checkpoints currently kept on disk

    model = FragSimiModel(cfg).to(cfg.device)
    print(model)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
    )

    if model_path and os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['state_dict'])
        # Optimizer state is intentionally NOT restored: resuming starts
        # with fresh AdamW moments (the original code had this disabled).
        print(f"Resuming training")
        del checkpoint

    with open(log_file, 'a', encoding='utf8') as f:
        f.write(f'Start training:\n')
        f.write(f'Data path: {cfg.dataset_path}, valid ratio: {cfg.valid_ratio}\n')
        if model_path:
            f.write(f'Resuming from: {model_path}\n')
        print(model, file=f)
        f.write(f'\n')

    for epoch in range(cfg.epochs):
        print(f"Epoch: {epoch + 1}/{cfg.epochs}")
        train_loss, t_cos_sim = train_epoch(model, train_loader, optimizer,
                                            lr_scheduler, step)
        valid_loss, v_cos_sim = valid_epoch(model, valid_loader)

        # BUGFIX: ReduceLROnPlateau was constructed but never stepped, so the
        # learning rate could never decay. Step it on the validation loss.
        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)

        txt = f"Train Loss: {train_loss.avg:.4f} | Val Loss: {valid_loss.avg:.4f} | Train cos sim: {t_cos_sim:.4f} | Val cos sim: {v_cos_sim:.4f}"
        print(txt)
        # Closed handle instead of the original leaked open(...).write(...).
        with open(log_file, 'a') as f:
            f.write(f"Epoch {epoch + 1}/{cfg.epochs}: {txt}\n")

        # BUGFIX: previously `best_loss = valid_loss.avg` every epoch (dead
        # `if True:` debris), so the returned "best" loss was just the last
        # epoch's. Track the minimum instead.
        best_loss = min(best_loss, valid_loss.avg)

        # A checkpoint is saved every epoch; only the most recent
        # cfg.keep_best_models_num files are kept on disk.
        model_fn = (f"{savedir}/model-tloss{round(train_loss.avg, 3)}"
                    f"-vloss{round(valid_loss.avg, 3)}"
                    f"-tcos{round(t_cos_sim, 3)}"
                    f"-vcos{round(v_cos_sim, 3)}-epoch{epoch}.pth")
        model_fn_base = model_fn.replace('.pth', '')
        k = 1
        while os.path.exists(model_fn):   # de-duplicate identical metric names
            model_fn = model_fn_base + f'-{k}.pth'
            k += 1

        best_model_fns.append(model_fn)

        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': dict(CFG),
                    'train_loss': train_loss.avg,
                    'valid_loss': valid_loss.avg,
                    'train_cos_sim': t_cos_sim,
                    'val_cos_sim': v_cos_sim
                    }, model_fn)

        print("Saved new best model!")

        best_model_fnl = [fn for fn in best_model_fns if os.path.exists(fn)]
        for fn in best_model_fnl[:-cfg.keep_best_models_num]:
            os.remove(fn)
        best_model_fnl = best_model_fnl[-cfg.keep_best_models_num:]

    print(best_model_fnl, best_loss)
    return best_model_fnl, best_loss
|
| |
|
if __name__ == "__main__":
    import pickle
    from tqdm import tqdm

    # Optional argv[1]: a .json config file, or a .pth checkpoint whose saved
    # config is reused (keeping the currently configured dataset path).
    # Explicit argv checks replace the original bare `except: pass`, which
    # silently swallowed real config-loading errors.
    if len(sys.argv) > 1:
        conffn = sys.argv[1]
        if conffn.endswith('.json'):
            CFG.load(conffn)
        elif conffn.endswith('.pth'):
            dpath = CFG.dataset_path
            d = torch.load(conffn)
            CFG.load(d['config'])
            CFG.dataset_path = dpath
            print('Use config from', conffn)

    # Optional argv[2]: output directory (default 'data/').
    savedir = sys.argv[2] if len(sys.argv) > 2 else 'data/'
    # Portable and injection-safe replacement for os.system('mkdir -p ...').
    os.makedirs(savedir, exist_ok=True)

    # Optional argv[3]: checkpoint to resume from.
    prev_model_pth = sys.argv[3] if len(sys.argv) > 3 else None

    print(CFG)

    if os.path.isdir(CFG.dataset_path):
        # Directory of .mgf spectrum files -> PathDataset route.
        data = [os.path.join(CFG.dataset_path, i)
                for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
    elif os.path.isfile(CFG.dataset_path):
        if CFG.dataset_path.endswith('.pkl'):
            print(f'loading data from {CFG.dataset_path} ...')
            with open(CFG.dataset_path, 'rb') as f:
                data = pickle.load(f)
        else:
            with open(CFG.dataset_path) as f:
                data = json.load(f)
            # Cache the parsed JSON as a pickle for faster future loads.
            pklfn = CFG.dataset_path.replace('.json', '.pkl')
            if not os.path.exists(pklfn):
                with open(pklfn, 'wb') as f:
                    pickle.dump(data, f)
    else:
        # Previously this fell through and crashed later with a NameError
        # on `data`; fail early with a clear message instead.
        raise FileNotFoundError(f'dataset path not found: {CFG.dataset_path}')

    subdir = make_next_record_dir(savedir, f'train-neg-')
    # Snapshot the source files and GNN directory alongside the run
    # (best-effort shell copy, kept as in the original).
    os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
    CFG.save(f'{subdir}/config.json')

    modelfnl, _ = main(data, CFG, subdir, prev_model_pth)
|