from functools import partial

import torch
import torch.nn as nn

from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
    atom14_to_atom37,
)
from openfold.model.embedders import (
    InputEmbedder,
    RecyclingEmbedder,
    TemplateAngleEmbedder,
    TemplatePairEmbedder,
    ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule
from openfold.model.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from openfold.utils.loss import (
    compute_plddt,
)
from openfold.utils.tensor_utils import (
    dict_multimap,
    tensor_tree_map,
)


class AlphaFold(nn.Module):
    """
    AlphaFold 2.

    Implements Algorithm 2 (but with training).
    """
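    # Architectural summary (added note, not from the original source): the
    # input, recycling, template, and extra-MSA embedders feed the Evoformer
    # trunk; its single and pair representations drive the structure module,
    # and the auxiliary heads produce the remaining predictions from the
    # trunk outputs.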
    def __init__(self, config):
        """
        Args:
            config:
                A dict-like config object (like the one in config.py)
        """
        super(AlphaFold, self).__init__()

        self.globals = config.globals
        config = config.model
        template_config = config.template
        extra_msa_config = config.extra_msa

        # Embedders, Evoformer trunk, and structure module
        self.input_embedder = InputEmbedder(
            **config["input_embedder"],
        )
        self.recycling_embedder = RecyclingEmbedder(
            **config["recycling_embedder"],
        )
        self.template_angle_embedder = TemplateAngleEmbedder(
            **template_config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **template_config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **template_config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **template_config["template_pointwise_attention"],
        )
        self.extra_msa_embedder = ExtraMSAEmbedder(
            **extra_msa_config["extra_msa_embedder"],
        )
        self.extra_msa_stack = ExtraMSAStack(
            **extra_msa_config["extra_msa_stack"],
        )
        self.evoformer = EvoformerStack(
            **config["evoformer_stack"],
        )
        self.structure_module = StructureModule(
            **config["structure_module"],
        )

        # Auxiliary prediction heads run on the trunk outputs
        self.aux_heads = AuxiliaryHeads(
            config["heads"],
        )

        self.config = config
    def embed_templates(self, batch, z, pair_mask, templ_dim):
        # Embed the templates one at a time and stack the results
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.template.embed_angles:
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # Template angle embedding, later appended to the MSA stack
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # Template pair embedding
            t = build_template_pair_feat(
                single_template_feats,
                inf=self.config.template.inf,
                eps=self.config.template.eps,
                **self.config.template.distogram,
            ).to(z.dtype)
            t = self.template_pair_embedder(t)

            single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # Refine the stacked template pair embeddings
        t = self.template_pair_stack(
            template_embeds["pair"],
            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        # Attend over the templates to produce a single pair update
        t = self.template_pointwise_att(
            t,
            z,
            template_mask=batch["template_mask"].to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
        )
        # Zero the template contribution when no templates are present
        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if self.config.template.embed_angles:
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({"template_pair_embedding": t})

        return ret
    def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
        # Primary output dictionary
        outputs = {}

        # Cast float32 features to the model's working dtype
        dtype = next(self.parameters()).dtype
        for k in feats:
            if feats[k].dtype == torch.float32:
                feats[k] = feats[k].to(dtype=dtype)

        # Grab some data about the input
        batch_dims = feats["target_feat"].shape[:-2]
        no_batch_dims = len(batch_dims)
        n = feats["target_feat"].shape[-2]
        n_seq = feats["msa_feat"].shape[-3]
        device = feats["target_feat"].device

        # Prep the sequence, pair, and MSA masks
        seq_mask = feats["seq_mask"]
        pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
        msa_mask = feats["msa_mask"]

        # Initialize the MSA and pair representations
        # m: [*, S_c, N, C_m]
        # z: [*, N, N, C_z]
        m, z = self.input_embedder(
            feats["target_feat"],
            feats["residue_index"],
            feats["msa_feat"],
        )

        # Initialize the recycling embeddings, if needs be
        if None in [m_1_prev, z_prev, x_prev]:
            # [*, N, C_m]
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.input_embedder.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.input_embedder.c_z),
                requires_grad=False,
            )

            # [*, N, 37, 3]
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # Reduce the recycled coordinates to pseudo-beta positions
        x_prev = pseudo_beta_fn(
            feats["aatype"], x_prev, None
        ).to(dtype=z.dtype)

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
        )

        # If recycling is disabled, zero the recycling updates rather than
        # skipping the embedder, so its parameters are not left unused
        if not _recycle:
            m_1_prev_emb *= 0
            z_prev_emb *= 0

        # Add the recycled update to the first row of the MSA embedding
        m[..., 0, :, :] += m_1_prev_emb

        # Add the recycled pair update
        z += z_prev_emb

        # Free tensors that are no longer needed
        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates and merge them into the MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
            )

            # [*, N, N, C_z]
            z = z + template_embeds["template_pair_embedding"]

            if self.config.template.embed_angles:
                # Append the template angle embeddings as extra MSA rows
                m = torch.cat(
                    [m, template_embeds["template_angle_embedding"]],
                    dim=-3,
                )

                # Extend the MSA mask accordingly
                torsion_angles_mask = feats["template_torsion_angles_mask"]
                msa_mask = torch.cat(
                    [feats["msa_mask"], torsion_angles_mask[..., 2]],
                    dim=-2,
                )

        # Embed the extra MSA features and update the pair embedding
        if self.config.extra_msa.enabled:
            # [*, S_e, N, C_e]
            a = self.extra_msa_embedder(build_extra_msa_feat(feats))

            # [*, N, N, C_z]
            z = self.extra_msa_stack(
                a,
                z,
                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
                chunk_size=self.globals.chunk_size,
                pair_mask=pair_mask.to(dtype=z.dtype),
                _mask_trans=self.config._mask_trans,
            )

        # Run the MSA and pair embeddings through the Evoformer trunk
        # m: [*, S, N, C_m]
        # z: [*, N, N, C_z]
        # s: [*, N, C_s]
        m, z, s = self.evoformer(
            m,
            z,
            msa_mask=msa_mask.to(dtype=m.dtype),
            pair_mask=pair_mask.to(dtype=z.dtype),
            chunk_size=self.globals.chunk_size,
            _mask_trans=self.config._mask_trans,
        )

        outputs["msa"] = m[..., :n_seq, :, :]
        outputs["pair"] = z
        outputs["single"] = s

        # Predict the 3D structure
        outputs["sm"] = self.structure_module(
            s,
            z,
            feats["aatype"],
            mask=feats["seq_mask"].to(dtype=s.dtype),
        )
        outputs["final_atom_positions"] = atom14_to_atom37(
            outputs["sm"]["positions"][-1], feats
        )
        outputs["final_atom_mask"] = feats["atom37_atom_exists"]
        outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]

        # Save embeddings for use during the next recycling iteration

        # [*, N, C_m]
        m_1_prev = m[..., 0, :, :]

        # [*, N, N, C_z]
        z_prev = z

        # [*, N, 37, 3]
        x_prev = outputs["final_atom_positions"]

        return outputs, m_1_prev, z_prev, x_prev
    def _disable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = None
        self.evoformer.blocks_per_ckpt = None

        for b in self.extra_msa_stack.blocks:
            b.ckpt = False

    def _enable_activation_checkpointing(self):
        self.template_pair_stack.blocks_per_ckpt = (
            self.config.template.template_pair_stack.blocks_per_ckpt
        )
        self.evoformer.blocks_per_ckpt = (
            self.config.evoformer_stack.blocks_per_ckpt
        )

        for b in self.extra_msa_stack.blocks:
            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
    def forward(self, batch):
        """
        Args:
            batch:
                Dictionary of arguments outlined in Algorithm 2. Keys must
                include the official names of the features in the
                supplement subsection 1.2.9.

                The final dimension of each input must have length equal to
                the number of recycling iterations.

                Features (without the recycling dimension):

                    "aatype" ([*, N_res]):
                        Contrary to the supplement, this tensor of residue
                        indices is not one-hot.
                    "target_feat" ([*, N_res, C_tf])
                        One-hot encoding of the target sequence. C_tf is
                        config.model.input_embedder.tf_dim.
                    "residue_index" ([*, N_res])
                        Tensor whose final dimension consists of
                        consecutive indices from 0 to N_res.
                    "msa_feat" ([*, N_seq, N_res, C_msa])
                        MSA features, constructed as in the supplement.
                        C_msa is config.model.input_embedder.msa_dim.
                    "seq_mask" ([*, N_res])
                        1-D sequence mask
                    "msa_mask" ([*, N_seq, N_res])
                        MSA mask
                    "pair_mask" ([*, N_res, N_res])
                        2-D pair mask
                    "extra_msa_mask" ([*, N_extra, N_res])
                        Extra MSA mask
                    "template_mask" ([*, N_templ])
                        Template mask (on the level of templates, not
                        residues)
                    "template_aatype" ([*, N_templ, N_res])
                        Tensor of template residue indices (indices greater
                        than 19 are clamped to 20 (Unknown))
                    "template_all_atom_positions"
                        ([*, N_templ, N_res, 37, 3])
                        Template atom coordinates in atom37 format
                    "template_all_atom_mask" ([*, N_templ, N_res, 37])
                        Template atom coordinate mask
                    "template_pseudo_beta" ([*, N_templ, N_res, 3])
                        Positions of template carbon "pseudo-beta" atoms
                        (i.e. C_beta for all residues but glycine, for
                        which C_alpha is used instead)
                    "template_pseudo_beta_mask" ([*, N_templ, N_res])
                        Pseudo-beta mask
        """
        # Initialize the recycling embeddings
        m_1_prev, z_prev, x_prev = None, None, None

        # Activation checkpointing only pays off on the final, grad-enabled
        # recycling iteration, so disable it for the earlier ones
        is_grad_enabled = torch.is_grad_enabled()
        self._disable_activation_checkpointing()

        # Main recycling loop
        num_iters = batch["aatype"].shape[-1]
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad only on the final recycling iteration (and only if
            # the caller had grad enabled in the first place)
            is_final_iter = cycle_no == (num_iters - 1)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    self._enable_activation_checkpointing()
                    # Clear the autocast cache before the grad-enabled pass
                    if torch.is_autocast_enabled():
                        torch.clear_autocast_cache()

                # Run the next iteration of the model
                outputs, m_1_prev, z_prev, x_prev = self.iteration(
                    feats,
                    m_1_prev,
                    z_prev,
                    x_prev,
                    _recycle=(num_iters > 1),
                )

        # Run the auxiliary heads on the final outputs
        outputs.update(self.aux_heads(outputs))

        return outputs
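# Example usage (a minimal sketch; `model_config` is assumed to be the config
# factory in openfold/config.py, and the feature pipeline that builds `batch`
# is not shown):
#
#     from openfold.config import model_config
#
#     config = model_config("model_1")
#     model = AlphaFold(config).eval()
#
#     # `batch` must be produced by the data pipeline so that every feature
#     # listed in forward()'s docstring is present, with the recycling
#     # dimension as the final dimension of each tensor.
#     with torch.no_grad():
#         outputs = model(batch)
#
#     coords = outputs["final_atom_positions"]  # [*, N_res, 37, 3]
#     mask = outputs["final_atom_mask"]         # [*, N_res, 37]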