File size: 2,312 Bytes
ca7299e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
"""Neural network for embedding node features."""
import torch
from torch import nn
from models.utils import get_index_embedding, get_time_embedding, add_RoPE
class NodeEmbedder(nn.Module):
    """Project per-residue input features to a ``c_s``-wide node representation.

    Three feature groups are concatenated along the last dimension and passed
    through a two-layer MLP (Linear -> ReLU -> Dropout -> Linear):

    * a learned embedding of ``aatype`` (21 entries, width ``c_pos_emb``),
    * a timestep embedding of width ``c_timestep_emb``,
    * a precomputed node representation of width ``c_node_pre``.

    Fields read from ``module_cfg``: ``c_s``, ``c_pos_emb``,
    ``c_timestep_emb``, ``dropout`` and, optionally, ``c_node_pre``.
    """

    def __init__(self, module_cfg):
        """Build the amino-acid embedding table and the projection MLP.

        Args:
            module_cfg: Config object exposing ``c_s``, ``c_pos_emb``,
                ``c_timestep_emb`` and ``dropout`` as attributes. An optional
                ``c_node_pre`` attribute overrides the width expected for
                ``node_repr_pre``; when it is missing or ``None`` the
                previously hard-coded value 1280 is used.
        """
        super().__init__()
        self._cfg = module_cfg
        self.c_s = self._cfg.c_s
        self.c_pos_emb = self._cfg.c_pos_emb
        self.c_timestep_emb = self._cfg.c_timestep_emb

        # Some config containers return None for a missing key instead of
        # raising AttributeError, so both cases fall back to the default.
        c_node_pre = getattr(self._cfg, "c_node_pre", None)
        self.c_node_pre = 1280 if c_node_pre is None else c_node_pre

        # The amino-acid embedding reuses the width configured as c_pos_emb.
        self.aatype_emb_dim = self._cfg.c_pos_emb
        self.aatype_emb = nn.Embedding(21, self.aatype_emb_dim)

        total_node_feats = (
            self.aatype_emb_dim + self.c_timestep_emb + self.c_node_pre
        )
        self.linear = nn.Sequential(
            nn.Linear(total_node_feats, self.c_s),
            nn.ReLU(),
            nn.Dropout(self._cfg.dropout),
            nn.Linear(self.c_s, self.c_s),
        )

    def embed_t(self, timesteps, mask):
        """Return a per-residue timestep embedding, scaled by ``mask``.

        Args:
            timesteps: Tensor documented by the caller as ``[B, 1]``; only
                column 0 is used.
            mask: Tensor documented by the caller as ``[B, L]``.

        Returns:
            Tensor of shape ``[B, L, c_timestep_emb]``; every position is the
            batch element's timestep embedding multiplied by its mask value
            (so positions where ``mask`` is 0 are zero).
        """
        timestep_emb = get_time_embedding(
            timesteps[:, 0],
            self.c_timestep_emb,
            max_positions=2056,
        )
        # [B, 1, C] * [B, L, 1] broadcasts to [B, L, C]; an explicit
        # ``repeat`` over L would only allocate a redundant copy first.
        return timestep_emb[:, None, :] * mask.unsqueeze(-1)

    def forward(self, timesteps, aatype, node_repr_pre, mask):
        """Embed node features and apply ``add_RoPE`` to the result.

        Args:
            timesteps: ``[B, 1]`` timestep per batch element.
            aatype: Integer indices into the 21-entry amino-acid embedding
                table (presumably ``[B, L]`` -- confirm against the caller).
            node_repr_pre: Precomputed node features; must be concatenable
                with the other features on the last dimension, i.e.
                ``[B, L, c_node_pre]``.
            mask: ``[B, L]`` residue mask.

        Returns:
            Output of ``add_RoPE`` applied to the ``[B, L, c_s]`` projected
            features (see ``models.utils.add_RoPE`` for its exact contract).
        """
        residue_mask = mask.unsqueeze(-1)
        input_feats = [
            # [B, L, aatype_emb_dim]
            self.aatype_emb(aatype) * residue_mask,
            # [B, L, c_timestep_emb]
            self.embed_t(timesteps, mask),
            # Passed through as-is; unlike the two features above it is not
            # multiplied by the mask here.
            node_repr_pre,
        ]
        out = self.linear(torch.cat(input_feats, dim=-1))  # (B, L, c_s)
        return add_RoPE(out)