Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESMFold2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import TypeVar | |
| import numpy as np | |
| import torch | |
| from torch import Tensor | |
| from . import esmfold2_residue_constants as RC | |
| from .esmfold2_affine3d import Affine3D | |
# Constrained type variable for helpers that accept either a NumPy array or a
# torch tensor and hand back an object of the same kind they were given.
ArrayOrTensor = TypeVar("ArrayOrTensor", np.ndarray, Tensor)
def atom3_to_backbone_frames(bb_positions: torch.Tensor) -> Affine3D:
    """Build backbone frames from stacked N, CA and C positions.

    Args:
        bb_positions: tensor whose second-to-last dimension holds, in order,
            the N, CA and C atoms (unpacking fails unless that dimension has
            size 3); the last dimension is presumably xyz — TODO confirm.

    Returns:
        Affine3D: frames built by ``Affine3D.from_graham_schmidt(C, CA, N)``
        (presumably CA is the origin — see that constructor).
    """
    n_xyz, ca_xyz, c_xyz = bb_positions.unbind(dim=-2)
    frames = Affine3D.from_graham_schmidt(c_xyz, ca_xyz, n_xyz)
    return frames
def index_by_atom_name(
    atom37: ArrayOrTensor, atom_names: str | list[str], dim: int = -2
) -> ArrayOrTensor:
    """Select atoms by name from an atom37-indexed array or tensor.

    Args:
        atom37: NumPy array or torch tensor whose ``dim`` axis is indexed by
            ``RC.atom_order``.
        atom_names: a single atom name or a list of names.
        dim: axis to index (negative values count from the end).

    Returns:
        The selected entries, keeping ``dim``. When ``atom_names`` is a single
        string, that axis is squeezed away.
    """
    single = isinstance(atom_names, str)
    names = [atom_names] if single else atom_names
    atom_idx = [RC.atom_order[name] for name in names]

    axis = dim % atom37.ndim
    # Full slices everywhere except the target axis, which gets the index list.
    selector = [slice(None)] * atom37.ndim
    selector[axis] = atom_idx
    picked = atom37[tuple(selector)]  # type: ignore

    return picked.squeeze(axis) if single else picked
def get_protein_normalization_frame(coords: Tensor) -> Affine3D:
    """Given a set of coordinates for a protein, compute a single frame that can be used to normalize the coordinates.

    Specifically, we compute the average position of the N, CA, and C atoms and use those 3 points to construct a frame
    using the Gram-Schmidt algorithm. The average CA position is used as the origin of the frame.

    Only residues whose N, CA and C coordinates are all finite contribute to the averages.

    Args:
        coords (torch.FloatTensor): [..., L, 37, 3] tensor of coordinates (leading batch dims allowed)

    Returns:
        Affine3D: one frame per leading batch element. If no residue has a fully finite backbone, the averaged
        points are all zero.
    """
    bb_coords = index_by_atom_name(coords, ["N", "CA", "C"], dim=-2)  # [..., L, 3, 3]
    # Accumulate in at least float32: summing many residues in fp16 can overflow
    # (max 65504) and in bf16 loses most precision. float64 inputs stay float64.
    acc_dtype = torch.promote_types(bb_coords.dtype, torch.float32)
    bb_coords = bb_coords.to(acc_dtype)
    # A residue counts only if all three backbone atoms have finite xyz.
    coord_mask = torch.isfinite(bb_coords).all(dim=-1).all(dim=-1)  # [..., L]
    summed = bb_coords.masked_fill(~coord_mask[..., None, None], 0).sum(-3)
    n_valid = coord_mask.sum(-1)[..., None, None]
    # The epsilon avoids 0/0 when no residue is valid.
    average_position_per_n_ca_c = summed / (n_valid + 1e-8)
    frame = atom3_to_backbone_frames(average_position_per_n_ca_c.float())
    return frame
def apply_frame_to_coords(coords: Tensor, frame: Affine3D) -> Tensor:
    """Given a set of coordinates and a single frame, apply the frame's inverse to the coordinates.

    Args:
        coords (torch.FloatTensor): [..., L, 37, 3] tensor of coordinates; leading dims line up with the frame's
            batch shape.
        frame (Affine3D): Affine3D frame, one per leading batch element.

    Returns:
        torch.FloatTensor: [..., L, 37, 3] tensor of transformed coordinates. Batch elements whose frame has a
        zero translation are returned untransformed, and entries that were infinite in the input are copied
        through unchanged (including their sign). The input tensor is not modified.
    """
    coords_trans_rot = frame[..., None, None].invert().apply(coords)
    # A zero translation is treated as a degenerate frame (e.g. no residue had a
    # complete backbone); those batch elements keep their original coordinates.
    valid_frame = frame.trans.norm(dim=-1) > 0
    out = coords_trans_rot.where(valid_frame[..., None, None, None], coords)
    # Infinite inputs act as sentinels, but the transform turns them into NaN or
    # sign-mixed values. Copy the originals back rather than filling with +inf,
    # so -inf stays -inf.
    is_inf = torch.isinf(coords)
    return out.where(~is_inf, coords)
def normalize_coordinates(coords: Tensor) -> Tensor:
    """Express ``coords`` in the protein's own normalization frame.

    Convenience wrapper that computes the frame with
    ``get_protein_normalization_frame`` and then applies it with
    ``apply_frame_to_coords``.
    """
    normalization_frame = get_protein_normalization_frame(coords)
    return apply_frame_to_coords(coords, normalization_frame)