# ESMFold2 / esmfold2_input_builder.py
# lhallee's picture
# Upload folder using huggingface_hub
# 7e8d2fc verified
from dataclasses import dataclass
from typing import Any, Sequence, TypeAlias, Union
import numpy as np
from .esmfold2_msa import MSA
# Type accepted by ``ProteinInput.msa``: either an ``MSA`` instance or ``None``.
# The fmt markers keep the one-member-per-line layout from being collapsed by
# the formatter.
# fmt: off
MSAInput: TypeAlias = Union[
    MSA,
    None,
]
# fmt: on
@dataclass
class Modification:
    """A modification at one position of a protein, RNA or DNA sequence.

    ``serialize_structure_prediction_input`` writes only ``position`` and
    ``ccd``; ``smiles`` is not serialized.
    """

    position: int  # zero-indexed
    ccd: str  # presumably a Chemical Component Dictionary code — TODO confirm
    smiles: str | None = None  # TODO(mlee): add smiles support
@dataclass
class ProteinInput:
    """Protein entry of a ``StructurePredictionInput``.

    Serialized with ``"type": "protein"``. This is the only sequence input
    class in this module that carries an ``msa`` field.
    """

    # Either one identifier or a list of identifiers.
    # NOTE(review): a list presumably denotes several copies of the same
    # chain — confirm against the consumer.
    id: str | list[str]
    sequence: str
    # ``None`` (or an empty list) results in no "modifications" key on
    # serialization.
    modifications: list[Modification] | None = None
    # ``None`` is serialized as an explicit ``"msa": None`` entry.
    msa: MSAInput = None
@dataclass
class RNAInput:
    """RNA entry of a ``StructurePredictionInput``.

    Serialized with ``"type": "rna"``. It has no ``msa`` field, so the
    serialized entry carries no "msa" key.
    """

    # Either one identifier or a list of identifiers.
    # NOTE(review): a list presumably denotes several copies of the same
    # chain — confirm against the consumer.
    id: str | list[str]
    sequence: str
    # ``None`` (or an empty list) results in no "modifications" key on
    # serialization.
    modifications: list[Modification] | None = None
@dataclass
class DNAInput:
    """DNA entry of a ``StructurePredictionInput``.

    Serialized with ``"type": "dna"``. It has no ``msa`` field, so the
    serialized entry carries no "msa" key.
    """

    # Either one identifier or a list of identifiers.
    # NOTE(review): a list presumably denotes several copies of the same
    # chain — confirm against the consumer.
    id: str | list[str]
    sequence: str
    # ``None`` (or an empty list) results in no "modifications" key on
    # serialization.
    modifications: list[Modification] | None = None
@dataclass
class LigandInput:
    """Ligand entry of a ``StructurePredictionInput``.

    Serialized with ``"type": "ligand"``; ``smiles`` and ``ccd`` are always
    written, even when they are ``None``. Nothing in this class checks that
    at least one of the two is provided.
    """

    # Either one identifier or a list of identifiers.
    id: str | list[str]
    # NOTE(review): presumably a SMILES string describing the ligand — confirm.
    smiles: str | None = None
    # NOTE(review): presumably Chemical Component Dictionary codes — confirm.
    ccd: list[str] | None = None
@dataclass
class DistogramConditioning:
    """A distogram array associated with one chain id.

    The array is converted with ``.tolist()`` when serialized and rebuilt
    with ``np.asarray`` when deserialized.
    """

    chain_id: str
    # NOTE(review): expected shape and dtype are not stated in this module —
    # confirm against the consumer.
    distogram: np.ndarray
@dataclass
class PocketConditioning:
    """Pocket conditioning: a binder chain id plus a list of contacts."""

    binder_chain_id: str
    # (str, int) pairs — presumably (chain id, residue index); TODO confirm.
    # Deserialization rebuilds each contact as a tuple.
    contacts: list[tuple[str, int]]
@dataclass
class CovalentBond:
    """A bond between two atoms.

    Each endpoint is given as a (chain id, residue index, atom index) triple.
    NOTE(review): whether the indices are zero- or one-based is not stated in
    this module — confirm against the consumer.
    """

    # First endpoint.
    chain_id1: str
    res_idx1: int
    atom_idx1: int
    # Second endpoint.
    chain_id2: str
    res_idx2: int
    atom_idx2: int
@dataclass
class StructurePredictionInput:
    """Top-level input: sequence/ligand entries plus optional extras.

    Each optional field left as ``None`` is omitted from the dict produced by
    ``serialize_structure_prediction_input``.
    """

    # Protein, RNA, DNA and ligand entries; serialized in this order.
    sequences: Sequence[ProteinInput | RNAInput | DNAInput | LigandInput]
    pocket: PocketConditioning | None = None
    distogram_conditioning: list[DistogramConditioning] | None = None
    covalent_bonds: list[CovalentBond] | None = None
def serialize_structure_prediction_input(all_atom_input: StructurePredictionInput):
    """Convert a ``StructurePredictionInput`` into a plain nested ``dict``.

    The result always contains a ``"sequences"`` list with one dict per input
    entry, in input order. The keys ``"covalent_bonds"``, ``"pocket"`` and
    ``"distogram_conditioning"`` are added only when the corresponding
    attribute of ``all_atom_input`` is not ``None``.

    Raises:
        AttributeError: if an entry has an ``msa`` attribute that is neither
            ``None`` nor an ``MSA`` instance.
        ValueError: if an entry is not a ``ProteinInput``, ``RNAInput``,
            ``DNAInput`` or ``LigandInput``.
    """

    def _polymer_record(seq_input, kind: str) -> dict[str, Any]:
        # Shared layout for protein / RNA / DNA entries.
        record: dict[str, Any] = {}
        record["sequence"] = seq_input.sequence
        record["id"] = seq_input.id
        record["type"] = kind

        # Only position and ccd are written; a missing or empty modification
        # list produces no "modifications" key at all.
        if getattr(seq_input, "modifications", None):
            record["modifications"] = [
                {"position": item.position, "ccd": item.ccd}
                for item in seq_input.modifications
            ]

        # Entries without an ``msa`` attribute get no "msa" key; an explicit
        # None is preserved as None.
        if hasattr(seq_input, "msa"):
            if seq_input.msa is None:
                record["msa"] = None
            elif isinstance(seq_input.msa, MSA):
                record["msa"] = {"sequences": seq_input.msa.sequences}
            else:
                raise AttributeError(
                    f"MSA must be None or MSA. Got {seq_input.msa} instead."
                )
        return record

    def _encode(seq_input) -> dict[str, Any]:
        # Map one input entry to its serialized dict based on its class.
        if isinstance(seq_input, ProteinInput):
            return _polymer_record(seq_input, "protein")
        if isinstance(seq_input, RNAInput):
            return _polymer_record(seq_input, "rna")
        if isinstance(seq_input, DNAInput):
            return _polymer_record(seq_input, "dna")
        if isinstance(seq_input, LigandInput):
            return {
                "smiles": seq_input.smiles,
                "id": seq_input.id,
                "ccd": seq_input.ccd,
                "type": "ligand",
            }
        raise ValueError(f"Unsupported sequence input type: {type(seq_input)}")

    result: dict[str, Any] = {
        "sequences": [_encode(entry) for entry in all_atom_input.sequences]
    }

    bonds = all_atom_input.covalent_bonds
    if bonds is not None:
        # Field order here fixes the key order of each serialized bond.
        bond_fields = (
            "chain_id1",
            "res_idx1",
            "atom_idx1",
            "chain_id2",
            "res_idx2",
            "atom_idx2",
        )
        result["covalent_bonds"] = [
            {name: getattr(bond, name) for name in bond_fields} for bond in bonds
        ]

    pocket = all_atom_input.pocket
    if pocket is not None:
        # The contacts list is passed through as-is (not copied).
        result["pocket"] = {
            "binder_chain_id": pocket.binder_chain_id,
            "contacts": pocket.contacts,
        }

    conditioning = all_atom_input.distogram_conditioning
    if conditioning is not None:
        # Arrays become nested Python lists.
        result["distogram_conditioning"] = [
            {"chain_id": item.chain_id, "distogram": item.distogram.tolist()}
            for item in conditioning
        ]

    return result
def deserialize_structure_prediction_input(
data: dict[str, Any],
) -> StructurePredictionInput:
"""Inverse of :func:`serialize_structure_prediction_input`.
Reconstructs a :class:`StructurePredictionInput` from the JSON-safe dict
produced by ``serialize_structure_prediction_input``. Values round-trip;
``DistogramConditioning.distogram`` dtype follows from JSON (``int64``
for integer entries, ``float64`` for floats) — cast back to the original
dtype if downstream code requires a specific one.
"""
def _mods(chain: dict[str, Any]) -> list[Modification] | None:
raw = chain.get("modifications")
if not raw:
return None
return [Modification(position=m["position"], ccd=m["ccd"]) for m in raw]
def _msa(chain: dict[str, Any]) -> MSAInput:
if "msa" not in chain or chain["msa"] is None:
return None
msa_blk = chain["msa"]
if isinstance(msa_blk, str):
raise ValueError(f"Unexpected MSA string value: {msa_blk!r}")
return MSA.from_sequences(msa_blk["sequences"])
sequences: list[ProteinInput | RNAInput | DNAInput | LigandInput] = []
for chain in data["sequences"]:
t = chain["type"]
if t == "protein":
sequences.append(
ProteinInput(
id=chain["id"],
sequence=chain["sequence"],
modifications=_mods(chain),
msa=_msa(chain),
)
)
elif t == "rna":
sequences.append(
RNAInput(
id=chain["id"],
sequence=chain["sequence"],
modifications=_mods(chain),
)
)
elif t == "dna":
sequences.append(
DNAInput(
id=chain["id"],
sequence=chain["sequence"],
modifications=_mods(chain),
)
)
elif t == "ligand":
sequences.append(
LigandInput(
id=chain["id"], smiles=chain.get("smiles"), ccd=chain.get("ccd")
)
)
else:
raise ValueError(f"Unsupported sequence type: {t!r}")
pocket: PocketConditioning | None = None
if (pocket_blk := data.get("pocket")) is not None:
pocket = PocketConditioning(
binder_chain_id=pocket_blk["binder_chain_id"],
contacts=[tuple(c) for c in pocket_blk["contacts"]],
)
distogram_conditioning: list[DistogramConditioning] | None = None
if (disto_blk := data.get("distogram_conditioning")) is not None:
distogram_conditioning = [
DistogramConditioning(
chain_id=d["chain_id"], distogram=np.asarray(d["distogram"])
)
for d in disto_blk
]
covalent_bonds: list[CovalentBond] | None = None
if (bonds_blk := data.get("covalent_bonds")) is not None:
covalent_bonds = [
CovalentBond(
chain_id1=b["chain_id1"],
res_idx1=b["res_idx1"],
atom_idx1=b["atom_idx1"],
chain_id2=b["chain_id2"],
res_idx2=b["res_idx2"],
atom_idx2=b["atom_idx2"],
)
for b in bonds_blk
]
return StructurePredictionInput(
sequences=sequences,
pocket=pocket,
distogram_conditioning=distogram_conditioning,
covalent_bonds=covalent_bonds,
)