# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import reduce
import importlib
import math
import sys
from operator import mul

import torch
import torch.nn as nn
from typing import Optional, Tuple, Sequence

from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
from openfold.utils.feats import (
    frames_and_literature_positions_to_atom14_pos,
    torsion_angles_to_frames,
)
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
    dict_multimap,
    permute_final_dims,
    flatten_final_dims,
)

# Compiled CUDA extension providing the in-place softmax used by
# InvariantPointAttention when inplace_safe is enabled
attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")


class AngleResnetBlock(nn.Module):
    def __init__(self, c_hidden):
        """
        Args:
            c_hidden:
                Hidden channel dimension
        """
        super(AngleResnetBlock, self).__init__()

        self.c_hidden = c_hidden

        self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        s_initial = a

        a = self.relu(a)
        a = self.linear_1(a)
        a = self.relu(a)
        a = self.linear_2(a)

        return a + s_initial


class AngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            no_blocks:
                Number of resnet blocks
            no_angles:
                Number of torsion angles to generate
            epsilon:
                Small constant for normalization
        """
        super(AngleResnet, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_blocks = no_blocks
        self.no_angles = no_angles
        self.eps = epsilon

        self.linear_in = Linear(self.c_in, self.c_hidden)
        self.linear_initial = Linear(self.c_in, self.c_hidden)

        self.layers = nn.ModuleList()
        for _ in range(self.no_blocks):
            layer = AngleResnetBlock(c_hidden=self.c_hidden)
            self.layers.append(layer)

        self.linear_out = Linear(self.c_hidden, self.no_angles * 2)

        self.relu = nn.ReLU()

    def forward(
        self, s: torch.Tensor, s_initial: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_hidden] single embedding
            s_initial:
                [*, C_hidden] single embedding as of the start of the
                StructureModule
        Returns:
            [*, no_angles, 2] unnormalized and normalized predicted angles
        """
        # NOTE: The ReLU's applied to the inputs are absorbed into the
        # resnet.

        # [*, C_hidden]
        s_initial = self.relu(s_initial)
        s_initial = self.linear_initial(s_initial)
        s = self.relu(s)
        s = self.linear_in(s)
        s = s + s_initial

        for l in self.layers:
            s = l(s)

        s = self.relu(s)

        # [*, no_angles * 2]
        s = self.linear_out(s)

        # [*, no_angles, 2]
        s = s.view(s.shape[:-1] + (-1, 2))

        unnormalized_s = s
        norm_denom = torch.sqrt(
            torch.clamp(
                torch.sum(s ** 2, dim=-1, keepdim=True),
                min=self.eps,
            )
        )
        s = s / norm_denom

        return unnormalized_s, s
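
# A minimal usage sketch for AngleResnet; the dimensions below are
# illustrative assumptions, not values prescribed by this file (OpenFold's
# real hyperparameters come from its config):
#
#     resnet = AngleResnet(
#         c_in=384, c_hidden=128, no_blocks=2, no_angles=7, epsilon=1e-8
#     )
#     s = torch.randn(2, 256, 384)              # [*, N_res, c_in]
#     unnormalized, normalized = resnet(s, s)   # each [2, 256, 7, 2]
#
# Each angle is predicted as a raw 2-vector; `normalized` is that vector
# projected onto the unit circle, i.e. a (sin, cos) pair.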


class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.
    """
    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
        """
        super(InvariantPointAttention, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        self.linear_kv = Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        self.linear_b = Linear(self.c_z, self.no_heads)

        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
        Returns:
            [*, N_res, C_s] single representation update
        """
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            assert(sys.getrefcount(z[0]) == 2)
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if(is_fp16_enabled()):
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        # Each of the three logit terms (scalar, pair bias, point distances)
        # carries a sqrt(1/3) weight so their variances match at init
        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        if(inplace_safe):
            pt_att *= pt_att
        else:
            pt_att = pt_att ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        if(inplace_safe):
            pt_att *= head_weights
        else:
            pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        if(inplace_safe):
            a += pt_att
            del pt_att
            a += square_mask.unsqueeze(-3)
            # in-place softmax
            attn_core_inplace_cuda.forward_(
                a,
                reduce(mul, a.shape[:-1]),
                a.shape[-1],
            )
        else:
            a = a + pt_att
            a = a + square_mask.unsqueeze(-3)
            a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3).to(dtype=a.dtype)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        if(inplace_safe):
            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
            o_pt = [
                torch.matmul(a, v.to(a.dtype))
                for v in torch.unbind(v_pts, dim=-3)
            ]
            o_pt = torch.stack(o_pt, dim=-3)
        else:
            o_pt = torch.sum(
                (
                    a[..., None, :, :, None]
                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
                ),
                dim=-2,
            )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(
            torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
        )

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s
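
# A minimal usage sketch for InvariantPointAttention; the hyperparameters are
# illustrative assumptions, not values prescribed by this file:
#
#     ipa = InvariantPointAttention(
#         c_s=384, c_z=128, c_hidden=16, no_heads=12,
#         no_qk_points=4, no_v_points=8,
#     )
#     s = torch.randn(1, 256, 384)          # [*, N_res, C_s]
#     z = torch.randn(1, 256, 256, 128)     # [*, N_res, N_res, C_z]
#     r = Rigid.identity((1, 256), s.dtype, s.device, False, fmt="quat")
#     mask = s.new_ones(1, 256)
#     s_update = ipa(s, z, r, mask)         # [1, 256, 384]
#
# Per Algorithm 22, the attention logits combine three terms, each weighted
# by sqrt(1/3) so their variances match at initialization:
#
#     a_ij = softmax_j( 1/sqrt(3) * ( q_i . k_j / sqrt(c_hidden) + b_ij
#            - gamma^h * w_C / 2 * sum_p ||T_i(q_i^p) - T_j(k_j^p)||^2 ) )
#
# with w_C = sqrt(2 / (9 * no_qk_points)) and gamma^h = softplus(head_weights),
# which is exactly the scaling applied to `a`, `b`, and `pt_att` above.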


class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, c_s):
        """
        Args:
            c_s:
                Single representation channel dimension
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s

        self.linear = Linear(self.c_s, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector
        """
        # [*, N_res, 6]
        update = self.linear(s)

        return update
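
# The 6-D update packs a rigid transformation, per Algorithm 23: the first
# three channels are the (b, c, d) components of a quaternion whose real part
# is fixed at 1 (it is normalized when the update is composed downstream via
# Rigid.compose_q_update_vec), and the last three are a translation update.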


class StructureModuleTransitionLayer(nn.Module):
    def __init__(self, c):
        super(StructureModuleTransitionLayer, self).__init__()

        self.c = c

        self.linear_1 = Linear(self.c, self.c, init="relu")
        self.linear_2 = Linear(self.c, self.c, init="relu")
        self.linear_3 = Linear(self.c, self.c, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        s_initial = s
        s = self.linear_1(s)
        s = self.relu(s)
        s = self.linear_2(s)
        s = self.relu(s)
        s = self.linear_3(s)

        s = s + s_initial

        return s


class StructureModuleTransition(nn.Module):
    def __init__(self, c, num_layers, dropout_rate):
        super(StructureModuleTransition, self).__init__()

        self.c = c
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList()
        for _ in range(self.num_layers):
            l = StructureModuleTransitionLayer(self.c)
            self.layers.append(l)

        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = LayerNorm(self.c)

    def forward(self, s):
        for l in self.layers:
            s = l(s)

        s = self.dropout(s)
        s = self.layer_norm(s)

        return s
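
# A brief sketch of the transition stack; width and depth here are
# illustrative assumptions. Each StructureModuleTransitionLayer is a
# three-layer residual MLP at constant width, and the stack is followed by
# dropout and LayerNorm (Alg. 20 lines 8-9):
#
#     transition = StructureModuleTransition(c=384, num_layers=1, dropout_rate=0.1)
#     s = transition(torch.randn(1, 256, 384))  # shape preserved: [1, 256, 384]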


class StructureModule(nn.Module):
    def __init__(
        self,
        c_s,
        c_z,
        c_ipa,
        c_resnet,
        no_heads_ipa,
        no_qk_points,
        no_v_points,
        dropout_rate,
        no_blocks,
        no_transition_layers,
        no_resnet_blocks,
        no_angles,
        trans_scale_factor,
        epsilon,
        inf,
        **kwargs,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_ipa:
                IPA hidden channel dimension
            c_resnet:
                Angle resnet (Alg. 20 lines 11-14) hidden channel dimension
            no_heads_ipa:
                Number of IPA heads
            no_qk_points:
                Number of query/key points to generate during IPA
            no_v_points:
                Number of value points to generate during IPA
            dropout_rate:
                Dropout rate used throughout the layer
            no_blocks:
                Number of structure module blocks
            no_transition_layers:
                Number of layers in the single representation transition
                (Alg. 20 lines 8-9)
            no_resnet_blocks:
                Number of blocks in the angle resnet
            no_angles:
                Number of angles to generate in the angle resnet
            trans_scale_factor:
                Factor by which the predicted translations are scaled
            epsilon:
                Small number used in angle resnet normalization
            inf:
                Large number used for attention masking
        """
        super(StructureModule, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_ipa = c_ipa
        self.c_resnet = c_resnet
        self.no_heads_ipa = no_heads_ipa
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.dropout_rate = dropout_rate
        self.no_blocks = no_blocks
        self.no_transition_layers = no_transition_layers
        self.no_resnet_blocks = no_resnet_blocks
        self.no_angles = no_angles
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf

        # Buffers to be lazily initialized later
        # self.default_frames
        # self.group_idx
        # self.atom_mask
        # self.lit_positions

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)

        self.ipa = InvariantPointAttention(
            self.c_s,
            self.c_z,
            self.c_ipa,
            self.no_heads_ipa,
            self.no_qk_points,
            self.no_v_points,
            inf=self.inf,
            eps=self.epsilon,
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm_ipa = LayerNorm(self.c_s)

        self.transition = StructureModuleTransition(
            self.c_s,
            self.no_transition_layers,
            self.dropout_rate,
        )

        self.bb_update = BackboneUpdate(self.c_s)

        self.angle_resnet = AngleResnet(
            self.c_s,
            self.c_resnet,
            self.no_resnet_blocks,
            self.no_angles,
            self.epsilon,
        )

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask
        Returns:
            A dictionary of outputs
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if(_offload_inference):
            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )
        outputs = []
        for i in range(self.no_blocks):
            # [*, N, C_s]
            s = s + self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = self.ipa_dropout(s)
            s = self.layer_norm_ipa(s)
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, we convert our
            # quaternion-based transformations to rotation-matrix ones
            # here
            backb_to_global = Rigid(
                Rotation(
                    rot_mats=rigids.get_rots().get_rot_mats(),
                    quats=None
                ),
                rigids.get_trans(),
            )

            backb_to_global = backb_to_global.scale_translation(
                self.trans_scale_factor
            )

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            preds = {
                "frames": scaled_rigids.to_tensor_7(),
                "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                "unnormalized_angles": unnormalized_angles,
                "angles": angles,
                "positions": pred_xyz,
                "states": s,
            }

            outputs.append(preds)

            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if(_offload_inference):
            evoformer_output_dict["pair"] = (
                evoformer_output_dict["pair"].to(s.device)
            )

        outputs = dict_multimap(torch.stack, outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.tensor(
                    restype_rigid_group_default_frame,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "group_idx"):
            self.register_buffer(
                "group_idx",
                torch.tensor(
                    restype_atom14_to_rigid_group,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "atom_mask"):
            self.register_buffer(
                "atom_mask",
                torch.tensor(
                    restype_atom14_mask,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.tensor(
                    restype_atom14_rigid_group_positions,
                    dtype=float_dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Separated purely to make testing less annoying
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
        self, r, f  # [*, N, 8]  # [*, N]
    ):
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
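
# An end-to-end usage sketch for the StructureModule; every hyperparameter
# below is an illustrative assumption (the real values come from the OpenFold
# config), and the inputs are random stand-ins for Evoformer outputs:
#
#     sm = StructureModule(
#         c_s=384, c_z=128, c_ipa=16, c_resnet=128, no_heads_ipa=12,
#         no_qk_points=4, no_v_points=8, dropout_rate=0.1, no_blocks=8,
#         no_transition_layers=1, no_resnet_blocks=2, no_angles=7,
#         trans_scale_factor=10, epsilon=1e-8, inf=1e5,
#     )
#     out = sm(
#         {
#             "single": torch.randn(1, 256, 384),
#             "pair": torch.randn(1, 256, 256, 128),
#         },
#         aatype=torch.zeros(1, 256, dtype=torch.long),
#     )
#     # out["positions"]: [no_blocks, 1, 256, 14, 3] atom14 coordinates (one
#     # snapshot per block); out["single"]: [1, 256, 384] final single repr.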