# PepFlow/models_con/node.py
import torch
from torch import nn
from pepflow.modules.common.geometry import construct_3d_basis, global_to_local, get_backbone_dihedral_angles
from pepflow.modules.common.layers import AngularEncoding
from pepflow.modules.protein.constants import BBHeavyAtom, AA


class NodeEmbedder(nn.Module):

    def __init__(self, feat_dim, max_num_atoms, max_aa_types=22):
        super().__init__()
        self.max_num_atoms = max_num_atoms
        self.max_aa_types = max_aa_types
        self.feat_dim = feat_dim
        self.aatype_embed = nn.Embedding(self.max_aa_types, feat_dim)
        self.dihed_embed = AngularEncoding()
        # Input = amino-acid embedding + per-aa-type local coordinates + dihedral encoding
        infeat_dim = feat_dim + (self.max_aa_types * max_num_atoms * 3) + self.dihed_embed.get_out_dim(3)
        self.mlp = nn.Sequential(
            nn.Linear(infeat_dim, feat_dim * 2), nn.ReLU(),
            nn.Linear(feat_dim * 2, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim), nn.ReLU(),
            nn.Linear(feat_dim, feat_dim)
        )

    # def embed_t(self, timesteps, mask):
    #     timestep_emb = get_time_embedding(
    #         timesteps[:, 0],
    #         self.feat_dim,
    #         max_positions=2056
    #     )[:, None, :].repeat(1, mask.shape[1], 1)
    #     return timestep_emb

    def forward(self, aa, res_nb, chain_nb, pos_atoms, mask_atoms, structure_mask=None, sequence_mask=None):
        """
        Args:
            aa: (N, L). Amino acid type indices.
            res_nb: (N, L). Residue numbers.
            chain_nb: (N, L). Chain indices.
            pos_atoms: (N, L, A, 3). Atom coordinates.
            mask_atoms: (N, L, A). Atom masks.
            structure_mask: (N, L). True where the structure is known; positions to be
                generated are masked out to avoid leakage.
            sequence_mask: (N, L). True where the amino acid type is known; positions to
                be generated are masked out to avoid leakage.
        Returns:
            (N, L, feat_dim) per-residue node features.
        """
        N, L = aa.size()
        mask_residue = mask_atoms[:, :, BBHeavyAtom.CA]  # (N, L)

        # Keep only the first max_num_atoms atoms per residue
        pos_atoms = pos_atoms[:, :, :self.max_num_atoms]
        mask_atoms = mask_atoms[:, :, :self.max_num_atoms]

        # Amino acid identity features
        if sequence_mask is not None:
            # Avoid data leakage at training time
            aa = torch.where(sequence_mask, aa, torch.full_like(aa, fill_value=AA.UNK))
        aa_feat = self.aatype_embed(aa)  # (N, L, feat)

        # Coordinate features: local frame per residue built from CA, C, N
        R = construct_3d_basis(
            pos_atoms[:, :, BBHeavyAtom.CA],
            pos_atoms[:, :, BBHeavyAtom.C],
            pos_atoms[:, :, BBHeavyAtom.N]
        )
        t = pos_atoms[:, :, BBHeavyAtom.CA]
        crd = global_to_local(R, t, pos_atoms)  # (N, L, A, 3)
        crd_mask = mask_atoms[:, :, :, None].expand_as(crd)
        crd = torch.where(crd_mask, crd, torch.zeros_like(crd))

        # Scatter local coordinates into per-amino-acid-type slots (one-hot over aa types)
        aa_expand = aa[:, :, None, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        rng_expand = torch.arange(0, self.max_aa_types)[None, None, :, None, None].expand(N, L, self.max_aa_types, self.max_num_atoms, 3).to(aa_expand)
        place_mask = (aa_expand == rng_expand)
        crd_expand = crd[:, :, None, :, :].expand(N, L, self.max_aa_types, self.max_num_atoms, 3)
        crd_expand = torch.where(place_mask, crd_expand, torch.zeros_like(crd_expand))
        crd_feat = crd_expand.reshape(N, L, self.max_aa_types * self.max_num_atoms * 3)
        if structure_mask is not None:
            # Avoid data leakage at training time
            crd_feat = crd_feat * structure_mask[:, :, None]

        # Backbone dihedral features
        bb_dihedral, mask_bb_dihed = get_backbone_dihedral_angles(pos_atoms, chain_nb=chain_nb, res_nb=res_nb, mask=mask_residue)
        dihed_feat = self.dihed_embed(bb_dihedral[:, :, :, None]) * mask_bb_dihed[:, :, :, None]  # (N, L, 3, dihed/3)
        dihed_feat = dihed_feat.reshape(N, L, -1)
        if structure_mask is not None:
            # Avoid data leakage at training time: a dihedral angle also depends on the
            # neighbouring residues, so both neighbours must be known as well. This also
            # avoids slight leakage via dihedral angles of anchor residues.
            dihed_mask = torch.logical_and(
                structure_mask,
                torch.logical_and(
                    torch.roll(structure_mask, shifts=+1, dims=1),
                    torch.roll(structure_mask, shifts=-1, dims=1)
                ),
            )
            dihed_feat = dihed_feat * dihed_mask[:, :, None]

        # timestep embedding (disabled; `timesteps` is not part of the current signature)
        # timestep_emb = self.embed_t(timesteps, mask_residue)

        # Mix features and zero out residues without a CA atom
        out_feat = self.mlp(torch.cat([aa_feat, crd_feat, dihed_feat], dim=-1))  # (N, L, F)
        out_feat = out_feat * mask_residue[:, :, None]
        return out_feat
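

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It only checks tensor
# shapes with random inputs; the concrete values (feat_dim=128, max_num_atoms=15,
# batch size, sequence length) are illustrative assumptions, not values used by
# PepFlow itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    N, L, A = 2, 16, 15                      # batch, residues, atoms per residue (assumed)
    embedder = NodeEmbedder(feat_dim=128, max_num_atoms=A)

    aa = torch.randint(0, 20, (N, L))        # amino acid type indices
    res_nb = torch.arange(L)[None].repeat(N, 1)
    chain_nb = torch.zeros(N, L, dtype=torch.long)
    pos_atoms = torch.randn(N, L, A, 3)      # random coordinates, shape check only
    mask_atoms = torch.ones(N, L, A, dtype=torch.bool)
    structure_mask = torch.ones(N, L, dtype=torch.bool)
    sequence_mask = torch.ones(N, L, dtype=torch.bool)

    out = embedder(aa, res_nb, chain_nb, pos_atoms, mask_atoms,
                   structure_mask=structure_mask, sequence_mask=sequence_mask)
    print(out.shape)                         # expected: (N, L, 128)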