from typing import Any
import logging
import os
import random
import time

import numpy as np
import pandas as pd
import torch
import wandb
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.wandb import WandbLogger
from openfold.utils.exponential_moving_average import ExponentialMovingAverage as EMA
from openfold.utils.tensor_utils import tensor_tree_map

from analysis import metrics
from analysis import utils as au
from data import all_atom
from data import residue_constants
from data import so3_utils
from data import utils as du
from data.interpolant import Interpolant
from data.residue_constants import order2restype_with_mask
from experiments import utils as eu
from models import utils as mu
from models.flow_model import FlowModel

# Force synchronous CUDA kernel launches so errors surface at the offending
# call. Useful for debugging, but it slows training and should normally be off.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
class FlowModule(LightningModule): |
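    """Lightning module wrapping ``FlowModel``: computes flow-matching
    training losses, runs validation sampling with metrics, and writes
    sampled trajectories at prediction time."""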

    def __init__(self, cfg, folding_cfg=None):
        super().__init__()
        self._print_logger = logging.getLogger(__name__)
        self._exp_cfg = cfg.experiment
        self._model_cfg = cfg.model
        self._data_cfg = cfg.data
        self._interpolant_cfg = cfg.interpolant

        # Denoising network and the interpolant used for corruption/sampling.
        self.model = FlowModel(cfg.model)
        self.interpolant = Interpolant(cfg.interpolant)

        # Validation samples are written alongside the checkpoints.
        self._sample_write_dir = self._exp_cfg.checkpointer.dirpath
        os.makedirs(self._sample_write_dir, exist_ok=True)

        self.validation_epoch_metrics = []
        self.validation_epoch_samples = []
        self.save_hyperparameters()

    def on_train_start(self):
        self._epoch_start_time = time.time()

    def on_train_epoch_end(self):
        epoch_time = (time.time() - self._epoch_start_time) / 60.0
        self.log(
            'train/epoch_time_minutes',
            epoch_time,
            on_step=False,
            on_epoch=True,
            prog_bar=False
        )
        self._epoch_start_time = time.time()

    def model_step(self, noisy_batch: Any, use_mask_aatype=True, eps=1e-8):
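        """Compute flow-matching training losses on a noised batch.

        Returns a dict of per-example losses: translation, rotation
        vector-field, backbone-atom, distance-matrix, auxiliary, and the
        combined SE(3) vector-field loss (which also folds in the torsion
        and auxiliary terms).
        """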
        training_cfg = self._exp_cfg.training
        loss_mask = noisy_batch['res_mask']
        if training_cfg.min_plddt_mask is not None:
            # Drop low-confidence residues from every loss term.
            plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask
            loss_mask *= plddt_mask
        num_batch, num_res = loss_mask.shape

        # Ground-truth rigid frames and the SO(3) vector field that carries
        # the noised rotations at time t toward the ground truth.
        gt_trans_1 = noisy_batch['trans_1']
        gt_rotmats_1 = noisy_batch['rotmats_1']
        rotmats_t = noisy_batch['rotmats_t']
        gt_rot_vf = so3_utils.calc_rot_vf(
            rotmats_t, gt_rotmats_1.type(torch.float32))

        # Ground-truth backbone atoms (N, CA, C) in atom37 ordering, plus the
        # per-residue atom existence mask.
        gt_bb_atoms, atom37_mask = all_atom.to_atom37(
            gt_trans_1, gt_rotmats_1, aatype=noisy_batch['aatype'], get_mask=True)
        gt_bb_atoms = gt_bb_atoms[:, :, :3, :]
        gt_atoms = noisy_batch['all_atom_positions']

        # Time-dependent loss normalization: dividing errors by
        # 1 - min(t, clip) up-weights late times (t near 1) while the clip
        # keeps the factor bounded away from zero.
        t = noisy_batch['t']
        norm_scale = 1 - torch.min(
            t[..., None], torch.tensor(training_cfg.t_normalize_clip))

        # One forward pass: the model predicts clean (t=1) frames and torsion
        # angles from the noised input.
        model_output = self.model(noisy_batch, use_mask_aatype=use_mask_aatype)
        pred_trans_1 = model_output['pred_trans']
        pred_rotmats_1 = model_output['pred_rotmats']
        pred_torsions_with_CB = model_output['pred_torsions_with_CB']
        pred_rots_vf = so3_utils.calc_rot_vf(rotmats_t, pred_rotmats_1)

        # Translation loss: masked MSE on normalized translation errors; the
        # denominator counts unmasked residues times three coordinates.
        trans_error = (gt_trans_1 - pred_trans_1) / norm_scale
        loss_denom = torch.sum(loss_mask, dim=-1) * 3
        trans_loss = training_cfg.translation_loss_weight * torch.sum(
            trans_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom

        # Rotation loss: masked MSE between ground-truth and predicted SO(3)
        # vector fields, with the same normalization.
        rots_vf_error = (gt_rot_vf - pred_rots_vf) / norm_scale
        rots_vf_loss = training_cfg.rotation_loss_weights * torch.sum(
            rots_vf_error ** 2 * loss_mask[..., None],
            dim=(-1, -2)
        ) / loss_denom

        # Torsion loss: squared error against torsion angles computed from
        # the ground-truth all-atom structure. The first predicted torsion
        # channel is skipped ([:, :, 1:, :]), which appears to align the
        # remaining channels with the ground-truth layout.
        torsion_angles, torsion_mask = all_atom.prot_to_torsion_angles(
            noisy_batch['aatype'], gt_atoms, atom37_mask)
        torsion_loss = torch.sum(
            (torsion_angles - pred_torsions_with_CB[:, :, 1:, :]) ** 2
            * torsion_mask[..., None],
            dim=(-1, -2, -3)
        ) / torch.sum(torsion_mask, dim=(-1, -2))

        # All-atom loss: rebuild atom37 coordinates from the predicted frames
        # and torsions, rigidly align them to the ground truth, then take a
        # masked MSE on the scaled coordinates.
        pred_atoms, atoms_mask = all_atom.to_atom37(
            pred_trans_1, pred_rotmats_1, aatype=noisy_batch['aatype'],
            torsions_with_CB=pred_torsions_with_CB, get_mask=True)
        atoms_mask = atoms_mask * loss_mask[..., None]
        pred_atoms_flat, gt_atoms_flat, _ = du.batch_align_structures(
            pred_atoms.reshape(num_batch, -1, 3),
            gt_atoms.reshape(num_batch, -1, 3),
            mask=atoms_mask.reshape(num_batch, -1)
        )
        gt_atoms_scaled = gt_atoms_flat * training_cfg.bb_atom_scale / norm_scale
        pred_atoms_scaled = pred_atoms_flat * training_cfg.bb_atom_scale / norm_scale

        bb_atom_loss = torch.sum(
            (gt_atoms_scaled - pred_atoms_scaled) ** 2,
            dim=(-1, -2)
        ) / torch.sum(atoms_mask, dim=(-1, -2))
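
        # Two auxiliary distance losses follow: the full pairwise distance
        # matrix over backbone atoms, and distances between sequence-adjacent
        # backbone atoms. Both operate on the scaled coordinates.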
        # Pairwise distance-matrix loss over the num_res * 3 backbone atoms.
        pred_bb_atoms = all_atom.to_atom37(pred_trans_1, pred_rotmats_1)[:, :, :3]
        gt_bb_atoms_scaled = gt_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
        pred_bb_atoms_scaled = pred_bb_atoms * training_cfg.bb_atom_scale / norm_scale[..., None]
        gt_flat_atoms = gt_bb_atoms_scaled.reshape([num_batch, num_res * 3, 3])
        gt_pair_dists = torch.linalg.norm(
            gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
        pred_flat_atoms = pred_bb_atoms_scaled.reshape([num_batch, num_res * 3, 3])
        pred_pair_dists = torch.linalg.norm(
            pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

        # Flatten the residue mask to the atom axis; the same mask applies to
        # both rows and columns of the distance matrix.
        flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
        flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res * 3])
        flat_res_mask = flat_loss_mask

        gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
        pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
        pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :]

        # num_res is subtracted from the pair count, presumably to discount
        # trivial self-distance terms on the diagonal.
        pair_dist_mat_loss = torch.sum(
            (gt_pair_dists - pred_pair_dists) ** 2 * pair_dist_mask,
            dim=(1, 2)) / (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)

        # Sequential distance loss: distance between each backbone atom and
        # the corresponding atom three positions later in the flat ordering,
        # i.e. the same atom type in the next residue. The scaled flat
        # coordinates from the pairwise block are reused here.
        gt_seq_dists = torch.linalg.norm(
            gt_flat_atoms[:, :-3, :] - gt_flat_atoms[:, 3:, :], dim=-1)
        pred_seq_dists = torch.linalg.norm(
            pred_flat_atoms[:, :-3, :] - pred_flat_atoms[:, 3:, :], dim=-1)

        # A pair only counts when both of its residues are unmasked.
        flat_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))
        flat_mask = (flat_mask[:, 1:, :] * flat_mask[:, :-1, :]).reshape(
            [num_batch, 3 * (num_res - 1)])

        gt_seq_dists = gt_seq_dists * flat_mask
        pred_seq_dists = pred_seq_dists * flat_mask

        seq_dist_mat_loss = torch.sum(
            (gt_seq_dists - pred_seq_dists) ** 2 * flat_mask,
            dim=(1)) / torch.sum(flat_mask, dim=(1))

        dist_mat_loss = pair_dist_mat_loss + seq_dist_mat_loss

        # Combined SE(3) vector-field loss. The atom and distance terms are
        # auxiliary and only activate once t passes a threshold.
        se3_vf_loss = trans_loss + rots_vf_loss
        auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
            t[:, 0] > training_cfg.aux_loss_t_pass
        )
        auxiliary_loss *= self._exp_cfg.training.aux_loss_weight
        se3_vf_loss += auxiliary_loss
        se3_vf_loss += torsion_loss

        return {
            "trans_loss": trans_loss,
            "rots_vf_loss": rots_vf_loss,
            "bb_atom_loss": bb_atom_loss,
            "dist_mat_loss": dist_mat_loss,
            "auxiliary_loss": auxiliary_loss,
            "se3_vf_loss": se3_vf_loss,
        }

    def validation_step(self, batch: Any, batch_idx: int):
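        """Sample structures for the batch and record validation metrics.

        Each sample is compared against the ground truth via aligned backbone
        RMSD and TM-score, written out as a PDB file, and logged to W&B when a
        WandbLogger is attached.
        """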
        res_mask = batch['res_mask']
        self.interpolant.set_device(res_mask.device)
        num_batch, num_res = res_mask.shape

        # Sample final structures: [0] selects the atom37 trajectory returned
        # by the interpolant and [-1] takes its last (fully denoised) step.
        self.model.eval()
        samples = self.interpolant.sample(
            batch,
            self.model,
        )[0][-1].numpy()
        self.model.train()
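
        # Per-sample metrics: CA-CA geometry, aligned backbone RMSD against
        # the ground truth, and TM-score over CA coordinates.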
        batch_metrics = []
        for i in range(num_batch):
            final_pos = samples[i]

            ca_ca_metrics = metrics.calc_ca_ca_metrics(
                final_pos[:, residue_constants.atom_order['CA']])

            # Aligned RMSD over the N, CA, C backbone atoms.
            rmsd = metrics.calc_aligned_rmsd(
                final_pos[:, :3].reshape(-1, 3),
                batch['all_atom_positions'][i].cpu().numpy()[:, :3].reshape(-1, 3))
            seq_string = ''.join(
                [order2restype_with_mask[int(aa)] for aa in batch['aatype'][i].cpu()])

            # TM-score needs a sequence matching the coordinate length; fall
            # back to poly-alanine when the aatype length disagrees.
            if len(seq_string) != final_pos[:, 1].shape[0]:
                seq_string = 'A' * final_pos[:, 1].shape[0]

            tm_score, _ = metrics.calc_tm_score(
                final_pos[:, 1], batch['all_atom_positions'][i].cpu().numpy()[:, 1],
                seq_string, seq_string)

            valid_loss = {'rmsd_loss': rmsd,
                          'tm_score': tm_score}

            batch_metrics.append(ca_ca_metrics | valid_loss)
            saved_path = au.write_prot_to_pdb(
                final_pos,
                os.path.join(
                    self._sample_write_dir,
                    f'idx_{batch_idx}_len_{num_res}_rmsd_{rmsd}_tm_{tm_score}.pdb'),
                aatype=batch['aatype'][i].cpu().numpy(),
                no_indexing=True
            )
            if isinstance(self.logger, WandbLogger):
                self.validation_epoch_samples.append(
                    [saved_path, self.global_step, wandb.Molecule(saved_path)]
                )

        batch_metrics = pd.DataFrame(batch_metrics)

        # Print a one-line summary of the batch means.
        print("")
        for key in batch_metrics.columns:
            print('%s=%.3f ' % (key, np.mean(batch_metrics[key])), end='')

        self.validation_epoch_metrics.append(batch_metrics)

    def on_validation_epoch_end(self):
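        """Flush the accumulated sample table and log epoch-mean metrics."""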
        if len(self.validation_epoch_samples) > 0:
            self.logger.log_table(
                key='valid/samples',
                columns=["sample_path", "global_step", "Protein"],
                data=self.validation_epoch_samples)
            self.validation_epoch_samples.clear()
        val_epoch_metrics = pd.concat(self.validation_epoch_metrics)
        for metric_name, metric_val in val_epoch_metrics.mean().to_dict().items():
            # Only RMSD and TM-score are surfaced on the progress bar.
            self._log_scalar(
                f'valid/{metric_name}',
                metric_val,
                on_step=False,
                on_epoch=True,
                prog_bar=metric_name in ['rmsd_loss', 'tm_score'],
                batch_size=len(val_epoch_metrics),
            )
        self.validation_epoch_metrics.clear()

    def _log_scalar(
            self,
            key,
            value,
            on_step=True,
            on_epoch=False,
            prog_bar=True,
            batch_size=None,
            sync_dist=False,
            rank_zero_only=True
    ):
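        """Thin wrapper around ``self.log`` with this project's defaults."""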
        if sync_dist and rank_zero_only:
            raise ValueError('Unable to sync dist when rank_zero_only=True')
        self.log(
            key,
            value,
            on_step=on_step,
            on_epoch=on_epoch,
            prog_bar=prog_bar,
            batch_size=batch_size,
            sync_dist=sync_dist,
            rank_zero_only=rank_zero_only
        )

    def training_step(self, batch: Any, batch_idx: int):
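        """Replicate the batch, corrupt it, and optimize the configured loss."""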
        # Expand the incoming batch along the batch dimension, bounded both by
        # max_batch_size and by a quadratic per-residue memory budget.
        seq_len = batch["aatype"].shape[1]
        batch_size = min(
            self._data_cfg.sampler.max_batch_size,
            self._data_cfg.sampler.max_num_res_squared // seq_len**2)
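        # For example, with max_num_res_squared = 500_000 and seq_len = 500
        # this gives min(max_batch_size, 2): memory grows roughly with
        # seq_len**2, so longer proteins get smaller batches. (The numbers
        # here are illustrative, not taken from the config.)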

        for key, value in batch.items():
            batch[key] = value.repeat((batch_size,) + (1,) * (len(value.shape) - 1))

        self.interpolant.set_device(batch['res_mask'].device)
        noisy_batch = self.interpolant.corrupt_batch(batch)
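
        # Self-conditioning: on roughly half of the steps, run the model once
        # without gradients and feed its predicted translations back in as
        # 'trans_sc', so the model learns to refine its own predictions.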
        if self._interpolant_cfg.self_condition and random.random() > 0.5:
            with torch.no_grad():
                self.model.eval()
                model_sc = self.model(noisy_batch)
                noisy_batch['trans_sc'] = model_sc['pred_trans']
                self.model.train()
        batch_losses = self.model_step(noisy_batch)
        num_batch = batch_losses['bb_atom_loss'].shape[0]
        total_losses = {
            k: torch.mean(v) for k, v in batch_losses.items()
        }
        for k, v in total_losses.items():
            self._log_scalar(
                f"train/{k}", v, prog_bar=False, batch_size=num_batch)

        # Log the mean corruption time t for this batch.
        t = torch.squeeze(noisy_batch['t'])
        self._log_scalar(
            "train/t",
            np.mean(du.to_numpy(t)),
            prog_bar=False, batch_size=num_batch)

        self._log_scalar(
            "train/length", batch['res_mask'].shape[1],
            prog_bar=False, batch_size=num_batch)
        self._log_scalar(
            "train/batch_size", num_batch, prog_bar=False)

        # Optimize whichever loss key the experiment config selects
        # (e.g. se3_vf_loss).
        train_loss = total_losses[self._exp_cfg.training.loss]
        self._log_scalar(
            "train/loss", train_loss, batch_size=num_batch)

        return train_loss

    def configure_optimizers(self):
        # AdamW over trainable parameters, with hyperparameters taken directly
        # from the experiment config.
        return torch.optim.AdamW(
            params=filter(lambda p: p.requires_grad, self.model.parameters()),
            **self._exp_cfg.optimizer
        )

    def predict_step(self, batch, batch_idx):
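        """Sample denoising trajectories for the batch and write them to disk.

        Note: relies on ``self._output_dir``, which is expected to be set by
        the inference runner before prediction starts; it is not defined in
        ``__init__``.
        """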
        self.interpolant.set_device(batch['res_mask'].device)

        # Every residue is treated as diffused (no fixed motif).
        diffuse_mask = torch.ones(batch['aatype'].shape)

        filename = str(batch['filename'][0])
        sample_dir = os.path.join(
            self._output_dir, f'batch_idx_{batch_idx}_{filename}')

        self.model.eval()
        atom37_traj, model_traj, _ = self.interpolant.sample(
            batch, self.model
        )

        # Save per-sample trajectories: bb_traj holds the sampled backbone at
        # each step, and x0_traj the model's clean-structure predictions,
        # flipped, presumably so both trajectories run in the same temporal
        # direction.
        os.makedirs(sample_dir, exist_ok=True)
        for batch_index in range(atom37_traj[0].shape[0]):
            bb_traj = du.to_numpy(torch.stack(atom37_traj, dim=0))[:, batch_index]
            x0_traj = du.to_numpy(torch.stack(model_traj, dim=0))[:, batch_index]
            _ = eu.save_traj(
                bb_traj[-1],
                bb_traj,
                np.flip(x0_traj, axis=0),
                du.to_numpy(diffuse_mask)[batch_index],
                output_dir=sample_dir,
                aatype=batch['aatype'][batch_index].cpu().numpy(),
                index=batch_index,
            )

        with open(os.path.join(sample_dir, 'seq.txt'), 'w') as f:
            f.write(batch['seq'][0])