Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Synthyra/ESMFold2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESMFold2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from dataclasses import dataclass | |
| from typing import Any, Sequence, TypeAlias, Union | |
| import numpy as np | |
| from .esmfold2_msa import MSA | |
# Accepted values for ``ProteinInput.msa``: an ``MSA`` instance, or ``None``
# when no multiple sequence alignment is supplied for the chain.
MSAInput: TypeAlias = Union[MSA, None]
| class Modification: | |
| position: int # zero-indexed | |
| ccd: str | |
| smiles: str | None = None # TODO(mlee): add smiles support | |
@dataclass
class ProteinInput:
    """A protein chain: its id(s), sequence, optional modifications and MSA.

    Serialized with ``"type": "protein"``; this is the only chain type that
    carries an ``msa`` field.
    """

    id: str | list[str]  # a list presumably denotes several copies of the chain — confirm
    sequence: str
    modifications: list[Modification] | None = None
    msa: MSAInput = None  # None, or an MSA instance
@dataclass
class RNAInput:
    """An RNA chain: its id(s), sequence and optional modifications.

    Serialized with ``"type": "rna"``; unlike ``ProteinInput`` it has no MSA.
    """

    id: str | list[str]
    sequence: str
    modifications: list[Modification] | None = None
@dataclass
class DNAInput:
    """A DNA chain: its id(s), sequence and optional modifications.

    Serialized with ``"type": "dna"``; unlike ``ProteinInput`` it has no MSA.
    """

    id: str | list[str]
    sequence: str
    modifications: list[Modification] | None = None
| class LigandInput: | |
| id: str | list[str] | |
| smiles: str | None = None | |
| ccd: list[str] | None = None | |
@dataclass
class DistogramConditioning:
    """A distogram used to condition the prediction of one chain.

    NOTE: the generated ``__eq__`` compares the ``distogram`` ndarray, which
    raises ``ValueError`` (ambiguous truth value) for distinct multi-element
    arrays; compare fields explicitly with ``np.array_equal`` instead.
    """

    chain_id: str
    distogram: np.ndarray  # shape/dtype not fixed here — confirm with the model code
@dataclass
class PocketConditioning:
    """Pocket conditioning: a binder chain plus the contacts it should make.

    ``contacts`` entries are ``(str, int)`` pairs — presumably
    ``(chain_id, residue_index)``; confirm against the model code.
    """

    binder_chain_id: str
    contacts: list[tuple[str, int]]
@dataclass
class CovalentBond:
    """A covalent bond between two atoms, each addressed by chain/residue/atom.

    Index conventions (zero- vs one-based) are not established here — confirm
    with the featurization code before constructing bonds by hand.
    """

    chain_id1: str
    res_idx1: int
    atom_idx1: int
    chain_id2: str
    res_idx2: int
    atom_idx2: int
@dataclass
class StructurePredictionInput:
    """A full structure-prediction request.

    Holds the chains to predict plus optional pocket, distogram and
    covalent-bond conditioning; optional parts default to None.
    """

    sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
    pocket: PocketConditioning | None = None
    distogram_conditioning: list[DistogramConditioning] | None = None
    covalent_bonds: list[CovalentBond] | None = None
def serialize_structure_prediction_input(
    all_atom_input: StructurePredictionInput,
) -> dict[str, Any]:
    """Convert a :class:`StructurePredictionInput` into a JSON-safe dict.

    Each entry of ``all_atom_input.sequences`` becomes a dict tagged with a
    ``"type"`` of ``"protein"``, ``"rna"``, ``"dna"`` or ``"ligand"``.
    The optional top-level blocks ``covalent_bonds``, ``pocket`` and
    ``distogram_conditioning`` are emitted only when they are not None.
    Distograms are written as nested Python lists.

    Args:
        all_atom_input: the structure-prediction request to serialize.

    Returns:
        A dict with a ``"sequences"`` list plus any present optional blocks.

    Raises:
        AttributeError: if a chain's ``msa`` is neither None nor an ``MSA``.
        ValueError: if a sequence entry is not a ProteinInput, RNAInput,
            DNAInput or LigandInput.
    """

    def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]:
        # Shared encoder for protein / RNA / DNA chains.
        chain_data: dict[str, Any] = {
            "sequence": seq_input.sequence,
            "id": seq_input.id,
            "type": chain_type,
        }
        # Missing or empty modification lists are omitted entirely. Only
        # ``position`` and ``ccd`` are written; ``Modification.smiles`` is dropped.
        modifications = getattr(seq_input, "modifications", None)
        if modifications:
            chain_data["modifications"] = [
                {"position": mod.position, "ccd": mod.ccd} for mod in modifications
            ]
        # Only chain types with an ``msa`` attribute (ProteinInput) get the key.
        if hasattr(seq_input, "msa"):
            msa = seq_input.msa
            if msa is None:
                chain_data["msa"] = None
            elif isinstance(msa, MSA):
                chain_data["msa"] = {"sequences": msa.sequences}
            else:
                # NOTE(review): TypeError would be more conventional, but
                # AttributeError is kept because callers may already catch it.
                raise AttributeError(f"MSA must be None or MSA. Got {msa} instead.")
        return chain_data

    sequences: list[dict[str, Any]] = []
    for seq_input in all_atom_input.sequences:
        if isinstance(seq_input, ProteinInput):
            sequences.append(create_chain_data(seq_input, "protein"))
        elif isinstance(seq_input, RNAInput):
            sequences.append(create_chain_data(seq_input, "rna"))
        elif isinstance(seq_input, DNAInput):
            sequences.append(create_chain_data(seq_input, "dna"))
        elif isinstance(seq_input, LigandInput):
            sequences.append(
                {
                    "smiles": seq_input.smiles,
                    "id": seq_input.id,
                    "ccd": seq_input.ccd,
                    "type": "ligand",
                }
            )
        else:
            raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")

    result: dict[str, Any] = {"sequences": sequences}
    if all_atom_input.covalent_bonds is not None:
        result["covalent_bonds"] = [
            {
                "chain_id1": bond.chain_id1,
                "res_idx1": bond.res_idx1,
                "atom_idx1": bond.atom_idx1,
                "chain_id2": bond.chain_id2,
                "res_idx2": bond.res_idx2,
                "atom_idx2": bond.atom_idx2,
            }
            for bond in all_atom_input.covalent_bonds
        ]
    if all_atom_input.pocket is not None:
        # Contacts are passed through as-is (tuples encode as JSON arrays).
        result["pocket"] = {
            "binder_chain_id": all_atom_input.pocket.binder_chain_id,
            "contacts": all_atom_input.pocket.contacts,
        }
    if all_atom_input.distogram_conditioning is not None:
        # np.asarray accepts any array-like (e.g. nested lists) and is a no-op
        # for an ndarray, so output for existing callers is unchanged.
        result["distogram_conditioning"] = [
            {
                "chain_id": disto.chain_id,
                "distogram": np.asarray(disto.distogram).tolist(),
            }
            for disto in all_atom_input.distogram_conditioning
        ]
    return result
| def deserialize_structure_prediction_input( | |
| data: dict[str, Any], | |
| ) -> StructurePredictionInput: | |
| """Inverse of :func:`serialize_structure_prediction_input`. | |
| Reconstructs a :class:`StructurePredictionInput` from the JSON-safe dict | |
| produced by ``serialize_structure_prediction_input``. Values round-trip; | |
| ``DistogramConditioning.distogram`` dtype follows from JSON (``int64`` | |
| for integer entries, ``float64`` for floats) — cast back to the original | |
| dtype if downstream code requires a specific one. | |
| """ | |
| def _mods(chain: dict[str, Any]) -> list[Modification] | None: | |
| raw = chain.get("modifications") | |
| if not raw: | |
| return None | |
| return [Modification(position=m["position"], ccd=m["ccd"]) for m in raw] | |
| def _msa(chain: dict[str, Any]) -> MSAInput: | |
| if "msa" not in chain or chain["msa"] is None: | |
| return None | |
| msa_blk = chain["msa"] | |
| if isinstance(msa_blk, str): | |
| raise ValueError(f"Unexpected MSA string value: {msa_blk!r}") | |
| return MSA.from_sequences(msa_blk["sequences"]) | |
| sequences: list[ProteinInput | RNAInput | DNAInput | LigandInput] = [] | |
| for chain in data["sequences"]: | |
| t = chain["type"] | |
| if t == "protein": | |
| sequences.append( | |
| ProteinInput( | |
| id=chain["id"], | |
| sequence=chain["sequence"], | |
| modifications=_mods(chain), | |
| msa=_msa(chain), | |
| ) | |
| ) | |
| elif t == "rna": | |
| sequences.append( | |
| RNAInput( | |
| id=chain["id"], | |
| sequence=chain["sequence"], | |
| modifications=_mods(chain), | |
| ) | |
| ) | |
| elif t == "dna": | |
| sequences.append( | |
| DNAInput( | |
| id=chain["id"], | |
| sequence=chain["sequence"], | |
| modifications=_mods(chain), | |
| ) | |
| ) | |
| elif t == "ligand": | |
| sequences.append( | |
| LigandInput( | |
| id=chain["id"], smiles=chain.get("smiles"), ccd=chain.get("ccd") | |
| ) | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported sequence type: {t!r}") | |
| pocket: PocketConditioning | None = None | |
| if (pocket_blk := data.get("pocket")) is not None: | |
| pocket = PocketConditioning( | |
| binder_chain_id=pocket_blk["binder_chain_id"], | |
| contacts=[tuple(c) for c in pocket_blk["contacts"]], | |
| ) | |
| distogram_conditioning: list[DistogramConditioning] | None = None | |
| if (disto_blk := data.get("distogram_conditioning")) is not None: | |
| distogram_conditioning = [ | |
| DistogramConditioning( | |
| chain_id=d["chain_id"], distogram=np.asarray(d["distogram"]) | |
| ) | |
| for d in disto_blk | |
| ] | |
| covalent_bonds: list[CovalentBond] | None = None | |
| if (bonds_blk := data.get("covalent_bonds")) is not None: | |
| covalent_bonds = [ | |
| CovalentBond( | |
| chain_id1=b["chain_id1"], | |
| res_idx1=b["res_idx1"], | |
| atom_idx1=b["atom_idx1"], | |
| chain_id2=b["chain_id2"], | |
| res_idx2=b["res_idx2"], | |
| atom_idx2=b["atom_idx2"], | |
| ) | |
| for b in bonds_blk | |
| ] | |
| return StructurePredictionInput( | |
| sequences=sequences, | |
| pocket=pocket, | |
| distogram_conditioning=distogram_conditioning, | |
| covalent_bonds=covalent_bonds, | |
| ) | |