| from typing import Tuple, List |
| import torch |
| from torch import _dynamo |
| # Fall back to eager execution instead of raising when torch.compile hits |
| # graph breaks in these modules. |
| _dynamo.config.suppress_errors = True |
| from torch import Tensor, nn |
| import loralib as lora |
| import math |
| import esm |
| from ..module.utils import ( |
| NeighborEmbedding, |
| Distance, |
| DistanceV2, |
| rbf_class_mapping, |
| act_class_mapping |
| ) |
| from ..module.attention import ( |
| EquivariantMultiHeadAttention, |
| EquivariantMultiHeadAttentionSoftMax, |
| EquivariantPAEMultiHeadAttention, |
| EquivariantPAEMultiHeadAttentionSoftMax, |
| EquivariantWeightedPAEMultiHeadAttention, |
| EquivariantWeightedPAEMultiHeadAttentionSoftMax, |
| EquivariantPAEMultiHeadAttentionSoftMaxFullGraph, |
| MultiHeadAttentionSoftMaxFullGraph, |
| MSAEncoderFullGraph, |
| EquivariantTriAngularMultiHeadAttention, |
| EquivariantTriAngularStarMultiHeadAttention, |
| EquivariantTriAngularStarDropMultiHeadAttention, |
| EquivariantTriAngularDropMultiHeadAttention, |
| PairFeatureNet, |
| TriangularSelfAttentionBlock, |
| SeqPairAttentionOutput, |
| MSAEncoder, |
| ESMMultiheadAttention |
| ) |
|
|
| |
| class PassForward(nn.Module): |
| """Identity encoder: returns its inputs unchanged (no message passing).""" |
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(PassForward, self).__init__() |
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
|
|
| def reset_parameters(self): |
| pass |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| edge_vec: Tensor = None, |
| edge_vec_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| |
| vec = node_vec_attr |
| attn_weight_layers = [] |
| return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
| |
| class ESMTransformerLayer(nn.Module): |
| """Transformer layer block.""" |
|
|
| def __init__( |
| self, |
| embed_dim, |
| ffn_embed_dim, |
| attention_heads, |
| add_bias_kv=True, |
| use_esm1b_layer_norm=False, |
| use_rotary_embeddings: bool = False, |
| ): |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.ffn_embed_dim = ffn_embed_dim |
| self.attention_heads = attention_heads |
| self.use_rotary_embeddings = use_rotary_embeddings |
| self._init_submodules(add_bias_kv, use_esm1b_layer_norm) |
|
|
| def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm): |
| BertLayerNorm = nn.LayerNorm |
|
|
| self.self_attn = ESMMultiheadAttention( |
| self.embed_dim, |
| self.attention_heads, |
| add_bias_kv=add_bias_kv, |
| add_zero_attn=False, |
| use_rotary_embeddings=self.use_rotary_embeddings, |
| ) |
| self.self_attn_layer_norm = BertLayerNorm(self.embed_dim) |
|
|
| self.fc1 = lora.Linear(self.embed_dim, self.ffn_embed_dim, r=16) |
| self.fc2 = lora.Linear(self.ffn_embed_dim, self.embed_dim, r=16) |
|
|
| self.final_layer_norm = BertLayerNorm(self.embed_dim) |
|
|
| def gelu(self, x): |
| """Implementation of the gelu activation function. |
| |
| For information: OpenAI GPT's gelu is slightly different |
| (and gives slightly different results): |
| 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) |
| """ |
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) |
|
|
| def forward( |
| self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False |
| ): |
| residual = x |
| x = self.self_attn_layer_norm(x) |
| x, attn = self.self_attn( |
| query=x, |
| key=x, |
| value=x, |
| key_padding_mask=self_attn_padding_mask, |
| need_weights=True, |
| need_head_weights=need_head_weights, |
| attn_mask=self_attn_mask, |
| ) |
| x = residual + x |
|
|
| residual = x |
| x = self.final_layer_norm(x) |
| x = self.gelu(self.fc1(x)) |
| x = self.fc2(x) |
| x = residual + x |
|
|
| return x, attn |
|
|
| |
| class LoRAESM2(nn.Module): |
| """ESM-2 (esm2_t33_650M_UR50D) encoder rebuilt with LoRA adapters (loralib) |
| on the token embedding and the feed-forward projections.""" |
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(LoRAESM2, self).__init__() |
| self.x_in_channels = x_in_channels |
| # Model dimensions are fixed to ESM-2 650M (esm2_t33_650M_UR50D), |
| # regardless of the constructor arguments above. |
| self.x_channels = 1280 |
| self.num_layers = 33 |
| self.embed_dim = 1280 |
| self.attention_heads = 20 |
| self.embed_scale = 1 |
| _, alphabet = esm.pretrained.esm2_t33_650M_UR50D() |
| self.alphabet = alphabet |
| self.alphabet_size = len(alphabet) |
| self.padding_idx = alphabet.padding_idx |
| self.mask_idx = alphabet.mask_idx |
| self.cls_idx = alphabet.cls_idx |
| self.eos_idx = alphabet.eos_idx |
| self.prepend_bos = alphabet.prepend_bos |
| self.append_eos = alphabet.append_eos |
| self.token_dropout = True |
|
|
| |
| self.embed_tokens = lora.Embedding( |
| self.alphabet_size, |
| self.embed_dim, |
| padding_idx=self.padding_idx, |
| r=16, |
| ) |
| self.layers = nn.ModuleList( |
| [ |
| ESMTransformerLayer( |
| self.embed_dim, |
| 4 * self.embed_dim, |
| self.attention_heads, |
| add_bias_kv=False, |
| use_esm1b_layer_norm=True, |
| use_rotary_embeddings=True, |
| ) |
| for _ in range(self.num_layers) |
| ] |
| ) |
| self.emb_layer_norm_after = nn.LayerNorm(self.embed_dim) |
|
|
| def reset_parameters(self): |
| |
| esm_weights, _ = esm.pretrained.esm2_t33_650M_UR50D() |
| self.load_state_dict(esm_weights.state_dict(), strict=False) |
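| # A minimal fine-tuning sketch (hypothetical usage, assuming loralib's |
| # mark_only_lora_as_trainable helper; adapt to the actual training loop): |
| # |
| #     model = LoRAESM2() |
| #     model.reset_parameters()                  # copy pretrained ESM-2 weights |
| #     lora.mark_only_lora_as_trainable(model)   # train only the LoRA A/B matrices |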
| |
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| edge_vec: Tensor = None, |
| edge_vec_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| |
| vec = node_vec_attr |
| attn_weight_layers = [] |
| tokens = x |
| |
| assert tokens.ndim == 2 |
| padding_mask = tokens.eq(self.padding_idx) |
|
|
| x = self.embed_scale * self.embed_tokens(tokens) |
|
|
| if self.token_dropout: |
| x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0) |
| # ESM-2 token-dropout rescaling: compensate for the zeroed <mask> embeddings |
| # by the ratio of the training mask rate (0.15 * 0.8) to the observed rate. |
| mask_ratio_train = 0.15 * 0.8 |
| src_lengths = (~padding_mask).sum(-1) |
| mask_ratio_observed = (tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths |
| x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] |
|
|
| if padding_mask is not None: |
| x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) |
|
|
| |
| x = x.transpose(0, 1) |
|
|
| if not padding_mask.any(): |
| padding_mask = None |
|
|
| for _, layer in enumerate(self.layers): |
| x, attn = layer( |
| x, |
| self_attn_padding_mask=padding_mask, |
| need_head_weights=False, |
| ) |
| attn_weight_layers.append(attn) |
|
|
| x = self.emb_layer_norm_after(x) |
| x = x.transpose(0, 1) |
|
|
| return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
| |
| class eqTransformer(nn.Module): |
| """The equivariant Transformer architecture. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.share_kv = share_kv |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.use_lora = use_lora |
| self.use_msa = x_use_msa |
|
|
| self.distance = Distance( |
| cutoff_lower, |
| cutoff_upper, |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.neighbor_embedding = ( |
| NeighborEmbedding( |
| x_channels, num_rbf + num_edge_attr, cutoff_lower, cutoff_upper, |
| ) |
| if neighbor_embedding |
| else None |
| ) |
| self.msa_encoder = MSAEncoder( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
| self.node_vec_proj = nn.Linear( |
| vec_in_channels, vec_channels, bias=False) |
|
|
| self.attention_layers = nn.ModuleList() |
| self._set_attn_layers() |
| self.drop = nn.Dropout(drop_out_rate) |
| self.out_norm = nn.LayerNorm(x_channels) |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| if self.neighbor_embedding is not None: |
| self.neighbor_embedding.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| self.out_norm.reset_parameters() |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| edge_vec: Tensor = None, |
| edge_vec_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| if edge_vec is None: |
| edge_index, edge_weight, edge_vec = self.distance(pos, edge_index) |
| else: |
| # Edge vectors were supplied by the caller; derive the edge weights |
| # (distances) from their norms for the RBF expansion below. |
| edge_weight = torch.norm(edge_vec, dim=-1) |
| assert ( |
| edge_vec is not None |
| ), "Distance module did not return directional information" |
| |
| edge_attr_distance = self.distance_expansion(edge_weight) |
| |
| # Append the RBF-expanded distances to any precomputed edge attributes. |
| edge_attr = ( |
| torch.cat([edge_attr, edge_attr_distance], dim=-1) |
| if edge_attr is not None |
| else edge_attr_distance |
| ) |
| |
| # Split off MSA features appended after the node embedding, if present. |
| if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
| if self.node_x_proj is not None: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| _, msa_edge_attr = self.msa_encoder(x_msa, edge_index) |
| edge_attr = torch.cat([edge_attr, msa_edge_attr], dim=-1) |
| mask = edge_index[0] != edge_index[1] |
| edge_vec[mask] = edge_vec[mask] / \ |
| torch.norm(edge_vec[mask], dim=1).unsqueeze(1) |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
| if self.neighbor_embedding is not None: |
| x = self.neighbor_embedding(x, edge_index, edge_weight, edge_attr) |
| |
| vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
| else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
| attn_weight_layers = [] |
| for attn in self.attention_layers: |
| dx, dvec, attn_weight = attn( |
| x, vec, edge_index, edge_weight, edge_attr, edge_vec) |
| x = x + self.drop(dx) |
| vec = vec + self.drop(dvec) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
|
|
| return x, vec, pos, edge_attr, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
|
|
|
|
| |
| class eqStarTransformer(eqTransformer): |
| """The equivariant Transformer architecture. |
| The first attention layer operates on the star graph; the remaining layers use the full graph. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStarTransformer, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| edge_index, edge_weight, edge_vec = self.distance(pos, edge_index) |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance( |
| pos, edge_index_star) |
|
|
| assert ( |
| edge_vec is not None and edge_vec_star is not None |
| ), "Distance module did not return directional information" |
| |
| edge_attr_distance = self.distance_expansion( |
| edge_weight) |
| edge_attr_distance_star = self.distance_expansion( |
| edge_weight_star) |
| |
| if edge_attr is not None: |
| |
| edge_attr = torch.cat([edge_attr, edge_attr_distance], dim=-1) |
| else: |
| edge_attr = edge_attr_distance |
| if edge_attr_star is not None: |
| edge_attr_star = torch.cat( |
| [edge_attr_star, edge_attr_distance_star], dim=-1) |
| else: |
| edge_attr_star = edge_attr_distance_star |
| |
| if self.node_x_proj is not None: |
| if x.shape[1] > self.x_in_channels: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x_msa = None |
| elif x.shape[1] > self.x_channels: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| |
| |
| |
| |
| # Normalize edge vectors for all non-self-loop edges; self-loops keep the |
| # zero vector. |
| mask = edge_index[0] != edge_index[1] |
| edge_vec[mask] = edge_vec[mask] / \ |
| torch.norm(edge_vec[mask], dim=1).unsqueeze(1) |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / \ |
| torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| if self.neighbor_embedding is not None: |
| |
| x = self.neighbor_embedding( |
| x, edge_index_star, edge_weight_star, edge_attr_star) |
| |
| vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
| else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
| attn_weight_layers = [] |
| for i, attn in enumerate(self.attention_layers): |
| |
| if i == 0: |
| dx, dvec, attn_weight = attn(x, vec, |
| edge_index_star, edge_weight_star, edge_attr_star, edge_vec_star, |
| return_attn=return_attn) |
| else: |
| dx, dvec, attn_weight = attn(x, vec, |
| edge_index, edge_weight, edge_attr, edge_vec, |
| return_attn=return_attn) |
| x = x + self.drop(dx) |
| vec = vec + self.drop(dvec) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
| |
| |
| |
| |
| |
| return x, vec, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
|
| |
| class eqTransformerSoftMax(eqTransformer): |
| """The equivariant Transformer architecture. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqTransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantMultiHeadAttentionSoftMax( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ) |
| self.attention_layers.append(layer) |
|
|
|
|
| |
| class eqStarTransformerSoftMax(eqStarTransformer): |
| """The equivariant Transformer architecture. |
| The first attention layer operates on the star graph; the remaining layers use the full graph. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStarTransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantMultiHeadAttentionSoftMax( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ) |
| self.attention_layers.append(layer) |
|
|
|
|
| class eqStar2TransformerSoftMax(eqStarTransformer): |
| """The equivariant Transformer architecture. |
| The first attention layer operates on the star graph; the remaining layers use the full graph. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStar2TransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
|
|
| def _set_attn_layers(self): |
| assert self.num_layers > 0, "num_layers must be greater than 0" |
| |
| self.attention_layers.append( |
| EquivariantMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| ) |
| |
| for _ in range(self.num_layers - 1): |
| layer = EquivariantMultiHeadAttentionSoftMax( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| # Layers after the first (star) layer do not receive the MSA-derived pairwise |
| # edge features (apparently 442 channels), so that width is subtracted when |
| # MSA input is enabled. |
| edge_attr_channels=self.num_rbf + self.num_edge_attr - 442 if self.use_msa else self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| self.attention_layers.append(layer) |
|
|
|
|
| class eqStar2PAETransformerSoftMax(eqStar2TransformerSoftMax): |
| """The equivariant Transformer architecture. |
| The first attention layer operates on the star graph; the remaining layers use the full graph. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStar2PAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
| |
| # Override the parent neighbor embedding: here it consumes only the raw edge |
| # attributes (num_edge_attr), without the RBF-expanded distances. |
| self.neighbor_embedding = ( |
| NeighborEmbedding( |
| x_channels, num_edge_attr, |
| cutoff_lower, cutoff_upper, |
| ) |
| if neighbor_embedding |
| else None |
| ) |
| if self.neighbor_embedding is not None: |
| self.neighbor_embedding.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| assert self.num_layers > 0, "num_layers must be greater than 0" |
| |
| self.attention_layers.append( |
| EquivariantPAEMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_dist_channels=self.num_rbf, |
| edge_attr_channels=self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| ) |
| |
| for _ in range(self.num_layers - 1): |
| layer = EquivariantPAEMultiHeadAttentionSoftMax( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_dist_channels=self.num_rbf, |
| edge_attr_channels=self.num_edge_attr - 442 if self.use_msa else self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| plddt: Tensor = None, |
| edge_confidence: Tensor = None, |
| edge_confidence_star: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| edge_index, edge_weight, edge_vec = self.distance(pos, edge_index) |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, edge_index_star) |
|
|
| assert ( |
| edge_vec is not None and edge_vec_star is not None |
| ), "Distance module did not return directional information" |
| |
| edge_attr_distance = self.distance_expansion( |
| edge_weight) |
| edge_attr_distance_star = self.distance_expansion( |
| edge_weight_star) |
| |
| if self.node_x_proj is not None: |
| if x.shape[1] > self.x_in_channels: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x_msa = None |
| elif x.shape[1] > self.x_channels: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| if edge_attr_star is not None: |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| else: |
| edge_attr_star = msa_edge_attr_star |
| |
| |
| |
| |
| |
| |
| mask = edge_index[0] != edge_index[1] |
| edge_vec[mask] = edge_vec[mask] / \ |
| torch.norm(edge_vec[mask], dim=1).unsqueeze(1) |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / \ |
| torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| if self.neighbor_embedding is not None: |
| |
| x = self.neighbor_embedding( |
| x, edge_index_star, edge_weight_star, edge_attr_star) |
| |
| vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
| else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
| attn_weight_layers = [] |
| for i, attn in enumerate(self.attention_layers): |
| |
| if i == 0: |
| dx, dvec, attn_weight = attn(x, vec, |
| edge_index_star, edge_confidence_star, |
| edge_attr_distance_star, edge_attr_star, |
| edge_vec_star, plddt, |
| return_attn=return_attn) |
| else: |
| dx, dvec, attn_weight = attn(x, vec, |
| edge_index, edge_confidence, |
| edge_attr_distance, edge_attr, |
| edge_vec, plddt, |
| return_attn=return_attn) |
| x = x + self.drop(dx) |
| vec = vec + self.drop(dvec) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
| return x, vec, pos, edge_attr_star, batch, attn_weight_layers |
|
|
|
|
| class eqStar2FullGraphPAETransformerSoftMax(nn.Module): |
| """The equivariant Transformer architecture. |
| First Layer is Star Graph, next layer is full graph |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStar2FullGraphPAETransformerSoftMax, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.share_kv = share_kv |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.use_lora = use_lora |
| self.use_msa = x_use_msa |
|
|
| self.distance = Distance( |
| cutoff_lower, |
| cutoff_upper, |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.neighbor_embedding = None |
| self.msa_encoder = MSAEncoderFullGraph( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
| self.node_vec_proj = nn.Linear( |
| vec_in_channels, vec_channels, bias=False) |
|
|
| self.attention_layers = nn.ModuleList() |
| self._set_attn_layers() |
| self.drop = nn.Dropout(drop_out_rate) |
| self.out_norm = nn.LayerNorm(x_channels) |
|
|
| self.reset_parameters() |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| if self.neighbor_embedding is not None: |
| self.neighbor_embedding.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| self.out_norm.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| assert self.num_layers > 0, "num_layers must be greater than 0" |
| |
| |
| input_dic = { |
| "x_channels": self.x_channels, |
| "x_hidden_channels": self.x_hidden_channels, |
| "vec_channels": self.vec_channels, |
| "vec_hidden_channels": self.vec_hidden_channels, |
| "share_kv": self.share_kv, |
| "edge_attr_dist_channels": self.num_rbf, |
| "edge_attr_channels": self.num_edge_attr, |
| "distance_influence": self.distance_influence, |
| "num_heads": self.num_heads, |
| "activation": act_class_mapping[self.activation], |
| "attn_activation": self.attn_activation, |
| "cutoff_lower": self.cutoff_lower, |
| "cutoff_upper": self.cutoff_upper, |
| "use_lora": self.use_lora |
| } |
| for _ in range(self.num_layers): |
| layer = EquivariantPAEMultiHeadAttentionSoftMaxFullGraph(**input_dic) |
| self.attention_layers.append(layer) |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor = None, |
| x_padding_mask: Tensor = None, |
| edge_index: Tensor = None, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| plddt: Tensor = None, |
| edge_confidence: Tensor = None, |
| edge_confidence_star: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| # Dense pairwise geometry: edge_vec[b, i, j] = pos[b, i] - pos[b, j], giving |
| # tensors of shape [batch, length, length, 3] and [batch, length, length]. |
| edge_vec = pos[:, :, None, :] - pos[:, None, :, :] |
| edge_weight = torch.norm(edge_vec, dim=-1) |
| |
| edge_attr_distance = self.distance_expansion(edge_weight) |
| |
| |
| # The full-graph variant assumes MSA features are always appended to x; |
| # split them off before projecting the node embedding. |
| x, x_msa = x[..., :self.x_in_channels], x[..., self.x_in_channels:] |
| |
| if self.msa_encoder is not None: |
| _, msa_edge_attr = self.msa_encoder(x_msa) |
| edge_attr = ( |
| torch.cat([edge_attr, msa_edge_attr], dim=-1) |
| if edge_attr is not None |
| else msa_edge_attr |
| ) |
| |
| # Normalize every off-diagonal pairwise direction vector; self-pairs on the |
| # diagonal keep the zero vector. |
| mask = ( |
| torch.ones(edge_vec.shape[:3], device=edge_vec.device, dtype=torch.bool) |
| ^ torch.eye(edge_vec.shape[1], device=edge_vec.device, dtype=torch.bool).unsqueeze(0) |
| ) |
| edge_vec[mask] = edge_vec[mask] / (torch.norm(edge_vec[mask], dim=-1, keepdim=True) + 1e-12) |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| |
| vec = self.node_vec_proj(node_vec_attr) if node_vec_attr is not None \ |
| else torch.zeros(x.size(0), 3, self.vec_channels, device=x.device) |
|
|
| attn_weight_layers = [] |
| for i, attn in enumerate(self.attention_layers): |
| |
| dx, dvec, attn_weight = attn(x, vec, |
| edge_index, edge_confidence, |
| edge_attr_distance, edge_attr, |
| edge_vec, plddt, x_padding_mask, |
| return_attn=return_attn) |
| x = x + self.drop(dx) |
| vec = vec + self.drop(dvec) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
| return x, vec, pos, [edge_confidence, edge_attr_distance, edge_attr, plddt], batch, attn_weight_layers |
|
|
|
|
| class FullGraphPAETransformerSoftMax(eqStar2FullGraphPAETransformerSoftMax): |
| """The equivariant Transformer architecture. |
| First Layer is Star Graph, next layer is full graph |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(FullGraphPAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
|
|
| def _set_attn_layers(self): |
| assert self.num_layers > 0, "num_layers must be greater than 0" |
| |
| |
| input_dic = { |
| "x_channels": self.x_channels, |
| "x_hidden_channels": self.x_hidden_channels, |
| "vec_channels": self.vec_channels, |
| "vec_hidden_channels": self.vec_hidden_channels, |
| "share_kv": self.share_kv, |
| "edge_attr_dist_channels": self.num_rbf, |
| "edge_attr_channels": self.num_edge_attr, |
| "distance_influence": self.distance_influence, |
| "num_heads": self.num_heads, |
| "activation": act_class_mapping[self.activation], |
| "attn_activation": self.attn_activation, |
| "cutoff_lower": self.cutoff_lower, |
| "cutoff_upper": self.cutoff_upper, |
| "use_lora": self.use_lora |
| } |
| for _ in range(self.num_layers): |
| layer = MultiHeadAttentionSoftMaxFullGraph(**input_dic) |
| self.attention_layers.append(layer) |
|
|
|
|
| class eqStar2WeightedPAETransformerSoftMax(eqStar2PAETransformerSoftMax): |
| """The equivariant Transformer architecture. |
| The first attention layer operates on the star graph; the remaining layers use the full graph. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnorm"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`True`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| share_kv=False, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnorm", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=True, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqStar2WeightedPAETransformerSoftMax, self).__init__(x_in_channels=x_in_channels, |
| x_channels=x_channels, |
| x_hidden_channels=x_hidden_channels, |
| vec_in_channels=vec_in_channels, |
| vec_channels=vec_channels, |
| vec_hidden_channels=vec_hidden_channels, |
| share_kv=share_kv, |
| num_layers=num_layers, |
| num_edge_attr=num_edge_attr, |
| num_rbf=num_rbf, |
| rbf_type=rbf_type, |
| trainable_rbf=trainable_rbf, |
| activation=activation, |
| attn_activation=attn_activation, |
| neighbor_embedding=neighbor_embedding, |
| num_heads=num_heads, |
| distance_influence=distance_influence, |
| cutoff_lower=cutoff_lower, |
| cutoff_upper=cutoff_upper, |
| x_in_embedding_type=x_in_embedding_type, |
| x_use_msa=x_use_msa, |
| drop_out_rate=drop_out_rate, |
| use_lora=use_lora) |
| |
| def _set_attn_layers(self): |
| assert self.num_layers > 0, "num_layers must be greater than 0" |
| |
| self.attention_layers.append( |
| EquivariantWeightedPAEMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_dist_channels=self.num_rbf, |
| edge_attr_channels=self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| ) |
| |
| for _ in range(self.num_layers - 1): |
| layer = EquivariantWeightedPAEMultiHeadAttentionSoftMax( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_channels, |
| vec_hidden_channels=self.vec_hidden_channels, |
| share_kv=self.share_kv, |
| edge_attr_dist_channels=self.num_rbf, |
| edge_attr_channels=self.num_edge_attr - 442 if self.use_msa else self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| use_lora=self.use_lora, |
| ) |
| self.attention_layers.append(layer) |
|
|
|
|
| class eqTriStarTransformer(nn.Module): |
| """The equivariant Transformer architecture. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`128`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnormunlim"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`False`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnormunlim", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=False, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=False, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqTriStarTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
|
|
| self.distance = DistanceV2( |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| |
| self.node_x_proj = None |
| if x_in_channels is not None: |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) if x_in_embedding_type == "Linear" \ |
| else nn.Embedding(x_in_channels, x_channels) |
|
|
| self.attention_layers = nn.ModuleList() |
| self._set_attn_layers() |
| self.drop = nn.Dropout(drop_out_rate) |
| self.out_norm = nn.LayerNorm(x_channels) |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantTriAngularMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_in_channels, |
| vec_hidden_channels=self.vec_channels, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| self.out_norm.reset_parameters() |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| coords = node_vec_attr + pos.unsqueeze(2) |
| edge_index, edge_weight, edge_vec = self.distance(pos, coords, edge_index) |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
| assert ( |
| edge_vec is not None |
| ), "Distance module did not return directional information" |
| |
| |
| |
| |
| |
| |
| edge_attr = torch.cat([edge_attr, self.distance_expansion(edge_weight)], dim=-1) |
| edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
| # normalize edge direction vectors, skipping self-loops to avoid dividing by zero |
| mask = edge_index[0] != edge_index[1] |
| edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask], dim=1).unsqueeze(1) |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| del mask, edge_weight, edge_weight_star |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
| attn_weight_layers = [] |
| for i, attn in enumerate(self.attention_layers): |
| if i == 0: |
| dx, edge_attr_star, attn_weight = attn( |
| x, edge_index_star, edge_attr_star, edge_vec_star) |
| else: |
| dx, edge_attr, attn_weight = attn( |
| x, edge_index, edge_attr, edge_vec) |
| x = x + self.drop(dx) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
| return x, None, pos, edge_attr, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
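| # Usage sketch for eqTriStarTransformer. All sizes below are small, illustrative |
| # assumptions (not the defaults or a trained configuration); it only shows how the |
| # forward signature fits together. |
| # |
| #   model = eqTriStarTransformer(x_channels=32, x_hidden_channels=32, vec_in_channels=4, |
| #                                vec_channels=16, num_layers=2, num_edge_attr=8, |
| #                                num_rbf=16, num_heads=4) |
| #   x = torch.randn(10, 32)                      # node features (x_channels wide) |
| #   pos = torch.randn(10, 3)                     # node positions |
| #   batch = torch.zeros(10, dtype=torch.long)    # single graph |
| #   edge_index = torch.randint(0, 10, (2, 40)) |
| #   edge_index_star = torch.randint(0, 10, (2, 40)) |
| #   edge_attr = torch.randn(40, 8)               # num_edge_attr per edge |
| #   edge_attr_star = torch.randn(40, 8) |
| #   node_vec_attr = torch.randn(10, 3, 4)        # vec_in_channels vectors per node |
| #   out_x, _, _, _, _, attn = model(x, pos, batch, edge_index, edge_index_star, |
| #                                   edge_attr, edge_attr_star, node_vec_attr=node_vec_attr, |
| #                                   return_attn=True) |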
|
|
|
|
| class eqMSATriStarTransformer(nn.Module): |
| """The equivariant Transformer architecture. Edge attributes are MSA weights. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`5120`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnormunlim"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`False`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnormunlim", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=False, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=True, |
| triangular_update=True, |
| ee_channels=None, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqMSATriStarTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.triangular_update = triangular_update |
|
|
| self.distance = DistanceV2( |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.msa_encoder = MSAEncoder( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
| self.ee_channels = ee_channels |
| self.attention_layers = nn.ModuleList() |
| self._set_attn_layers() |
| self.drop = nn.Dropout(drop_out_rate) |
| self.out_norm = nn.LayerNorm(x_channels) |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantTriAngularMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_in_channels, |
| vec_hidden_channels=self.vec_channels, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ee_channels=self.ee_channels, |
| triangular_update=self.triangular_update, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| self.out_norm.reset_parameters() |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| coords = node_vec_attr + pos.unsqueeze(2) |
| |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
| |
| # if x carries extra trailing feature columns, split them off as the raw MSA block |
| # (consumed by the MSA encoder below); the leading columns are the node embeddings |
| if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
| if self.node_x_proj is not None: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| |
| |
| |
| |
| |
| del edge_attr |
| edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
| |
| |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| del mask, edge_weight_star |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
|
|
| attn_weight_layers = [] |
| # NOTE: only the first attention layer (on the star graph) produces an update; the |
| # remaining layers are skipped with dx = 0, since the dense-graph edge_attr was |
| # discarded above. |
| for i, attn in enumerate(self.attention_layers): |
| if i == 0: |
| dx, edge_attr_star, attn_weight = attn( |
| x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
| else: |
| dx = 0 |
| x = x + self.drop(dx) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.out_norm(x) |
| return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
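| # Input-layout note for eqMSATriStarTransformer, inferred from the forward pass above: |
| # when an MSA is used, x is expected to be the node embedding concatenated with the raw |
| # MSA block along the feature dimension, e.g. |
| #   x = torch.cat([node_embeddings, msa_block], dim=-1) |
| # The first x_in_channels (or x_channels) columns go through node_x_proj, while the |
| # trailing columns are passed to MSAEncoder to produce extra star-edge attributes. |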
|
|
|
|
| class eqMSATriStarGRUTransformer(nn.Module): |
| """The equivariant Transformer architecture. Edge attributes are MSA weights. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`5120`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnormunlim"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`False`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnormunlim", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=False, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=True, |
| triangular_update=True, |
| ee_channels=None, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqMSATriStarGRUTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.triangular_update = triangular_update |
|
|
| self.distance = DistanceV2( |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.msa_encoder = MSAEncoder( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) |
| self.ee_channels = ee_channels |
| self.attention_layers = nn.ModuleList() |
| self._set_attn_layers() |
| self.drop = nn.Dropout(drop_out_rate) |
| |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantTriAngularStarMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_in_channels, |
| vec_hidden_channels=self.vec_channels, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| cutoff_lower=self.cutoff_lower, |
| cutoff_upper=self.cutoff_upper, |
| ee_channels=self.ee_channels, |
| triangular_update=self.triangular_update, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| |
|
|
| def forward( |
| self, |
| x: Tensor, |
| x_center: Tensor, |
| x_mask: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| coords = node_vec_attr + pos.unsqueeze(2) |
| |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
| |
| if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
| if self.node_x_proj is not None: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| |
| |
| |
| |
| |
| del edge_attr |
| edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
| |
| |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| del mask, edge_weight_star |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| # keep each node's own features where x_mask is True; use x_center features elsewhere |
| x = x * x_mask.unsqueeze(1) + x_center * (~x_mask).unsqueeze(1) |
|
|
| attn_weight_layers = [] |
| for _, attn in enumerate(self.attention_layers): |
| x, edge_attr_star, attn_weight = attn( |
| x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| x = self.drop(x) |
| |
| batch = batch[~x_mask] |
| return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
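| # Compared with eqMSATriStarTransformer, eqMSATriStarGRUTransformer takes two extra |
| # forward inputs: x_center (features substituted where x_mask is False) and x_mask |
| # (boolean per node; True keeps the node's own features). After the attention stack |
| # it also returns batch restricted to positions where x_mask is False (batch[~x_mask]). |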
|
|
|
|
| class eqMSATriStarDropTransformer(nn.Module): |
| """The equivariant Transformer architecture. Edge attributes are MSA weights, distances and drop out is applied. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`5120`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnormunlim"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`False`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnormunlim", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=False, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=True, |
| triangular_update=True, |
| ee_channels=None, |
| drop_out_rate=0, |
| use_lora=None, |
| layer_norm=True, |
| ): |
| super(eqMSATriStarDropTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.triangular_update = triangular_update |
| self.use_lora = use_lora |
| self.layer_norm = layer_norm |
|
|
| self.distance = DistanceV2( |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.msa_encoder = MSAEncoder( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| if use_lora is not None: |
| self.node_x_proj = lora.Linear(x_in_channels, x_channels, r=use_lora) |
| else: |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| lora.Linear(x_in_channels, x_channels, r=use_lora) if use_lora is not None else nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) if use_lora is None else lora.Embedding(x_in_channels, x_channels, r=use_lora) |
| self.ee_channels = ee_channels |
| self.attention_layers = nn.ModuleList() |
| |
| self.drop_out_rate = drop_out_rate |
| self._set_attn_layers() |
| |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantTriAngularDropMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_in_channels, |
| vec_hidden_channels=self.vec_channels, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| ee_channels=self.ee_channels, |
| rbf_channels=self.num_rbf, |
| triangular_update=self.triangular_update, |
| drop_out_rate=self.drop_out_rate, |
| use_lora=self.use_lora, |
| layer_norm=self.layer_norm, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| |
|
|
| def forward( |
| self, |
| x: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| coords = node_vec_attr + pos.unsqueeze(2) |
| |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
| |
| if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
| if self.node_x_proj is not None: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| |
| |
| |
| |
| |
| del edge_attr |
| edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
| |
| |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| del mask, edge_weight_star |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| |
|
|
| attn_weight_layers = [] |
| for _, attn in enumerate(self.attention_layers): |
| x, edge_attr_star, attn_weight = attn( |
| x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| |
| |
| |
| return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
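| # Construction sketch with LoRA adapters. The sizes and rank below are illustrative |
| # assumptions, not recommended settings: |
| #   model = eqMSATriStarDropTransformer(x_in_channels=1280, x_channels=512, |
| #                                       num_layers=2, drop_out_rate=0.1, use_lora=16) |
| # With use_lora set, node_x_proj is built from loralib (lora.Linear / lora.Embedding) |
| # and the same rank is forwarded to every attention layer; loralib's |
| # lora.mark_only_lora_as_trainable(model) can then restrict training to the adapters. |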
|
|
|
|
| class eqMSATriStarDropGRUTransformer(nn.Module): |
| """The equivariant Transformer architecture. Edge attributes are MSA weights, distances and drop out is applied. |
| |
| Args: |
| x_channels (int, optional): Hidden embedding size. |
| (default: :obj:`5120`) |
| num_layers (int, optional): The number of attention layers. |
| (default: :obj:`6`) |
| num_rbf (int, optional): The number of radial basis functions :math:`\mu`. |
| (default: :obj:`50`) |
| rbf_type (string, optional): The type of radial basis function to use. |
| (default: :obj:`"expnormunlim"`) |
| trainable_rbf (bool, optional): Whether to train RBF parameters with |
| backpropagation. (default: :obj:`True`) |
| activation (string, optional): The type of activation function to use. |
| (default: :obj:`"silu"`) |
| attn_activation (string, optional): The type of activation function to use |
| inside the attention mechanism. (default: :obj:`"silu"`) |
| neighbor_embedding (bool, optional): Whether to perform an initial neighbor |
| embedding step. (default: :obj:`False`) |
| num_heads (int, optional): Number of attention heads. |
| (default: :obj:`8`) |
| distance_influence (string, optional): Where distance information is used inside |
| the attention mechanism. (default: :obj:`"both"`) |
| cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions. |
| (default: :obj:`0.0`) |
| cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions. |
| (default: :obj:`5.0`) |
| """ |
|
|
| def __init__( |
| self, |
| x_in_channels=None, |
| x_channels=5120, |
| x_hidden_channels=1280, |
| vec_in_channels=4, |
| vec_channels=128, |
| vec_hidden_channels=5120, |
| num_layers=6, |
| num_edge_attr=145, |
| num_rbf=50, |
| rbf_type="expnormunlim", |
| trainable_rbf=True, |
| activation="silu", |
| attn_activation="silu", |
| neighbor_embedding=False, |
| num_heads=8, |
| distance_influence="both", |
| cutoff_lower=0.0, |
| cutoff_upper=5.0, |
| x_in_embedding_type="Linear", |
| x_use_msa=True, |
| triangular_update=True, |
| ee_channels=None, |
| drop_out_rate=0, |
| use_lora=None, |
| ): |
| super(eqMSATriStarDropGRUTransformer, self).__init__() |
|
|
| assert distance_influence in ["keys", "values", "both", "none"] |
| assert rbf_type in rbf_class_mapping, ( |
| f'Unknown RBF type "{rbf_type}". ' |
| f'Choose from {", ".join(rbf_class_mapping.keys())}.' |
| ) |
| assert activation in act_class_mapping, ( |
| f'Unknown activation function "{activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
| assert attn_activation in act_class_mapping, ( |
| f'Unknown attention activation function "{attn_activation}". ' |
| f'Choose from {", ".join(act_class_mapping.keys())}.' |
| ) |
|
|
| self.x_in_channels = x_in_channels |
| self.x_channels = x_channels |
| self.vec_in_channels = vec_in_channels |
| self.vec_channels = vec_channels |
| self.x_hidden_channels = x_hidden_channels |
| self.vec_hidden_channels = vec_hidden_channels |
| self.num_layers = num_layers |
| self.num_rbf = num_rbf |
| self.num_edge_attr = num_edge_attr |
| self.rbf_type = rbf_type |
| self.trainable_rbf = trainable_rbf |
| self.activation = activation |
| self.attn_activation = attn_activation |
| self.neighbor_embedding = neighbor_embedding |
| self.num_heads = num_heads |
| self.distance_influence = distance_influence |
| self.cutoff_lower = cutoff_lower |
| self.cutoff_upper = cutoff_upper |
| self.triangular_update = triangular_update |
| self.use_lora = use_lora |
|
|
| self.distance = DistanceV2( |
| return_vecs=True, |
| loop=True, |
| ) |
| self.distance_expansion = rbf_class_mapping[rbf_type]( |
| cutoff_lower, cutoff_upper, num_rbf, trainable_rbf |
| ) |
| self.msa_encoder = MSAEncoder( |
| num_species=199, |
| weighting_schema='spe', |
| pairwise_type='cov', |
| ) if x_use_msa else None |
|
|
| self.node_x_proj = None |
| if x_in_channels is not None: |
| if x_in_embedding_type == "Linear": |
| if use_lora is not None: |
| self.node_x_proj = lora.Linear(x_in_channels, x_channels, r=use_lora) |
| else: |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) |
| elif x_in_embedding_type == "Linear_gelu": |
| self.node_x_proj = nn.Sequential( |
| lora.Linear(x_in_channels, x_channels, r=use_lora) if use_lora is not None else nn.Linear(x_in_channels, x_channels), |
| nn.GELU(), |
| ) |
| else: |
| self.node_x_proj = nn.Embedding(x_in_channels, x_channels) if use_lora is None else lora.Embedding(x_in_channels, x_channels, r=use_lora) |
| self.ee_channels = ee_channels |
| self.attention_layers = nn.ModuleList() |
| |
| self.drop_out_rate = drop_out_rate |
| self._set_attn_layers() |
| |
|
|
| self.reset_parameters() |
|
|
| def _set_attn_layers(self): |
| for _ in range(self.num_layers): |
| layer = EquivariantTriAngularStarDropMultiHeadAttention( |
| x_channels=self.x_channels, |
| x_hidden_channels=self.x_hidden_channels, |
| vec_channels=self.vec_in_channels, |
| vec_hidden_channels=self.vec_channels, |
| edge_attr_channels=self.num_rbf + self.num_edge_attr, |
| distance_influence=self.distance_influence, |
| num_heads=self.num_heads, |
| activation=act_class_mapping[self.activation], |
| attn_activation=self.attn_activation, |
| ee_channels=self.ee_channels, |
| rbf_channels=self.num_rbf, |
| triangular_update=self.triangular_update, |
| drop_out_rate=self.drop_out_rate, |
| use_lora=self.use_lora, |
| ) |
| self.attention_layers.append(layer) |
|
|
| def reset_parameters(self): |
| self.distance_expansion.reset_parameters() |
| for attn in self.attention_layers: |
| attn.reset_parameters() |
| |
|
|
| def forward( |
| self, |
| x: Tensor, |
| x_center: Tensor, |
| x_mask: Tensor, |
| pos: Tensor, |
| batch: Tensor, |
| edge_index: Tensor, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, List]: |
| coords = node_vec_attr + pos.unsqueeze(2) |
| |
| edge_index_star, edge_weight_star, edge_vec_star = self.distance(pos, coords, edge_index_star) |
| |
| if (self.x_in_channels is not None and x.shape[1] > self.x_in_channels) or x.shape[1] > self.x_channels: |
| if self.node_x_proj is not None: |
| x, x_msa = x[:, :self.x_in_channels], x[:, self.x_in_channels:] |
| else: |
| x, x_msa = x[:, :self.x_channels], x[:, self.x_channels:] |
| else: |
| x_msa = None |
| |
| |
| |
| |
| |
| if self.msa_encoder is not None and x_msa is not None: |
| _, msa_edge_attr_star = self.msa_encoder(x_msa, edge_index_star) |
| edge_attr_star = torch.cat([edge_attr_star, msa_edge_attr_star], dim=-1) |
| |
| |
| |
| |
| |
| del edge_attr |
| edge_attr_star = torch.cat([edge_attr_star, self.distance_expansion(edge_weight_star)], dim=-1) |
| |
| |
| mask = edge_index_star[0] != edge_index_star[1] |
| edge_vec_star[mask] = edge_vec_star[mask] / torch.norm(edge_vec_star[mask], dim=1).unsqueeze(1) |
| del mask, edge_weight_star |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| x = x * x_mask.unsqueeze(1) + x_center * (~x_mask).unsqueeze(1) |
|
|
| attn_weight_layers = [] |
| for _, attn in enumerate(self.attention_layers): |
| x, edge_attr_star, attn_weight = attn( |
| x, coords, edge_index_star, edge_attr_star, edge_vec_star) |
| if return_attn: |
| attn_weight_layers.append(attn_weight) |
| |
| |
| batch = batch[~x_mask] |
| return x, None, pos, edge_attr_star, batch, attn_weight_layers |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"x_channels={self.x_channels}, " |
| f"x_hidden_channels={self.x_hidden_channels}, " |
| f"vec_in_channels={self.vec_in_channels}, " |
| f"vec_channels={self.vec_channels}, " |
| f"vec_hidden_channels={self.vec_hidden_channels}, " |
| f"num_layers={self.num_layers}, " |
| f"num_rbf={self.num_rbf}, " |
| f"rbf_type={self.rbf_type}, " |
| f"trainable_rbf={self.trainable_rbf}, " |
| f"activation={self.activation}, " |
| f"attn_activation={self.attn_activation}, " |
| f"neighbor_embedding={self.neighbor_embedding}, " |
| f"num_heads={self.num_heads}, " |
| f"distance_influence={self.distance_influence}, " |
| f"cutoff_lower={self.cutoff_lower}, " |
| f"cutoff_upper={self.cutoff_upper})" |
| ) |
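| # eqMSATriStarDropGRUTransformer combines the two variants above: it builds the |
| # Star/Drop attention layers (per-layer dropout and optional LoRA rank via use_lora) |
| # while keeping the x_center / x_mask forward interface of the GRU variant, including |
| # returning batch[~x_mask]. |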
|
|
|
|
| |
| class eqTriAttnTransformer(nn.Module): |
| """ |
| Takes a sequence representation and a structure as input and outputs updated sequence and pairwise representations. |
| """ |
|
|
| def __init__(self, |
| x_in_channels=None, |
| x_channels=1280, |
| pairwise_state_dim=128, |
| num_layers=4, |
| num_heads=8, |
| x_in_embedding_type="Embedding", |
| drop_out_rate=0.1, |
| x_hidden_channels=None, |
| vec_channels=None, |
| vec_in_channels=None, |
| vec_hidden_channels=None, |
| num_edge_attr=None, |
| num_rbf=None, |
| rbf_type=None, |
| trainable_rbf=None, |
| activation=None, |
| neighbor_embedding=None, |
| cutoff_lower=None, |
| cutoff_upper=None, |
| x_use_msa=False, |
| use_lora=None, |
| ): |
| super(eqTriAttnTransformer, self).__init__() |
| if x_in_channels is not None: |
| self.node_x_proj = nn.Linear(x_in_channels, x_channels) if x_in_embedding_type == "Linear" \ |
| else nn.Embedding(x_in_channels, x_channels) |
| else: |
| self.node_x_proj = None |
| assert x_channels % num_heads == 0 \ |
| and pairwise_state_dim % num_heads == 0, ( |
| f"The number of hidden channels x_channels ({x_channels}) " |
| f"and pair-wise channels ({pairwise_state_dim}) " |
| f"must be evenly divisible by the number of " |
| f"attention heads ({num_heads})" |
| ) |
| sequence_head_width = x_channels // num_heads |
| pairwise_head_width = pairwise_state_dim // num_heads |
| self.tri_attn_block = nn.ModuleList( |
| [ |
| TriangularSelfAttentionBlock( |
| sequence_state_dim=x_channels, |
| pairwise_state_dim=pairwise_state_dim, |
| sequence_head_width=sequence_head_width, |
| pairwise_head_width=pairwise_head_width, |
| dropout=drop_out_rate, |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
| self.seq_struct_to_pair = PairFeatureNet( |
| x_channels, pairwise_state_dim) |
| |
| |
| self.seq_pair_to_output = SeqPairAttentionOutput(seq_state_dim=x_channels, |
| pairwise_state_dim=pairwise_state_dim, |
| num_heads=num_heads, |
| output_dim=x_channels, |
| dropout=drop_out_rate) |
|
|
| def reset_parameters(self): |
| pass |
|
|
| def forward(self, |
| x: Tensor, |
| pos: Tensor, |
| residx: Tensor = None, |
| mask: Tensor = None, |
| batch: Tensor = None, |
| edge_index: Tensor = None, |
| edge_index_star: Tensor = None, |
| edge_attr: Tensor = None, |
| edge_attr_star: Tensor = None, |
| node_vec_attr: Tensor = None, |
| return_attn: bool = False, |
| ): |
| """ |
| Inputs: |
| x: B x L x C tensor of sequence features |
| pos: B x L x 4 x 3 tensor of [CA, CB, N, O] coordinates |
| residx: B x L long tensor giving the position in the sequence |
| mask: B x L boolean tensor indicating valid residues |
| |
| Outputs: |
| s_s: B x L x x_channels tensor of updated sequence features |
| s_z: B x L x L x pairwise_state_dim tensor of pairwise features |
| (pos is returned unchanged; the remaining three outputs are None placeholders) |
| """ |
|
|
| if residx is None: |
| residx = torch.arange( |
| x.shape[1], device=x.device).repeat(x.shape[0], 1) |
| if mask is None: |
| mask = torch.ones((x.shape[0], x.shape[1]), |
| dtype=torch.bool, device=x.device) |
| |
| x = self.node_x_proj(x) if self.node_x_proj is not None else x |
| |
| pair_feats = self.seq_struct_to_pair(x, pos, residx, mask) |
|
|
| s_s = x |
| s_z = pair_feats |
|
|
| for block in self.tri_attn_block: |
| s_s, s_z = block(sequence_state=s_s, |
| pairwise_state=s_z, |
| mask=mask.to(torch.float32)) |
|
|
| s_s = self.seq_pair_to_output( |
| sequence_state=s_s, pairwise_state=s_z, mask=mask.to(torch.float32)) |
| |
| |
| |
| return s_s, s_z, pos, None, None, None |
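| # Usage sketch for eqTriAttnTransformer on batch-first padded tensors. All sizes are |
| # illustrative assumptions, and the stated output shapes follow the module names |
| # (x_channels / pairwise_state_dim) rather than a verified run: |
| #   model = eqTriAttnTransformer(x_channels=64, pairwise_state_dim=32, |
| #                                num_layers=2, num_heads=8) |
| #   x = torch.randn(2, 50, 64)         # B x L x C sequence features |
| #   pos = torch.randn(2, 50, 4, 3)     # B x L x 4 x 3 [CA, CB, N, O] coordinates |
| #   s_s, s_z, _, _, _, _ = model(x, pos) |
| #   # s_s: 2 x 50 x 64 sequence state; s_z: 2 x 50 x 50 x 32 pairwise state |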
|
|