Image Feature Extraction
Transformers
Safetensors
esmfold2
biology
protein-structure
multimodal-protein-model
custom_code
Instructions to use Synthyra/ESMFold2-Fast with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/ESMFold2-Fast with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-feature-extraction", model="Synthyra/ESMFold2-Fast", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESMFold2-Fast", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 7,232 Bytes
fb8a87c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 | from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import TypeVar
import numpy as np
from .esmfold2_misc import concat_objects, slice_any_object
# Generic type variable used in the signature of SequentialDataclass.concat.
T = TypeVar("T")
@dataclass(frozen=True)
class SequentialDataclass(ABC):
"""
This is a builder on a dataclass that allows for automatic slicing and concatenation.
When representing multimodal data, we often have multiple datatypes which have sequence dimensions that are the same (e.g. the length of the protein).
When applying a transformation like a crop, we want to apply this to all tensors at the same time (e.g. crop the sequence, structure, and function).
We also have some fields that are not sequential (like an id, or data source), which we don't want to crop.
The SequentialDataclass abstracts this cropping away, allowing you to define dataclasses that implement `__len__`, `__getitem__` and `concat` automatically.
This is done through the `metadata` field, which can take 3 values:
`sequence` (bool): True or False, tells the dataclass whether this field is a sequential type. Default: False.
`sequence_dim` (int): Which dimension is the sequential dimension (e.g. for a list of inverse folded sequences, we want to index each sequence in the list, not the list itself). Default: 0.
`join_token` (Any): What token to use to join when concatenating elements. Default: None.
Example:
@dataclass(frozen=True)
class Foo(SequentialDataclass):
id: str
sequence: str = field(metadata={"sequence": True, "join_token": "|"})
tensor: torch.Tensor = field(metadata={"sequence": True, "join_token": torch.nan})
def __len__(self):
# Must implement the __len__ method
return len(self.sequence)
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(5))
Foo(id='foo', sequence='ABCDE', tensor=tensor([ 0.0252, -0.3335, -0.5143, 0.0251, -1.0717]))
>>> foo[1:4]
Foo(id='foo', sequence='BCD', tensor=tensor([-0.3335, -0.5143, 0.0251]))
>>> foo[np.arange(5) < 3]
Foo(id='foo', sequence='ABC', tensor=tensor([ 0.0252, -0.3335, -0.5143]))
>>> Foo.concat([foo[:2], foo[3:]])
Foo(id='foo', sequence='AB|DE', tensor=tensor([ 0.0252, -0.3335, nan, 0.0251, -1.0717]))
# Trying to create a type where the sequence lengths do not match raises an error
>>> foo = Foo(id="foo", sequence="ABCDE", tensor=torch.randn(6))
ValueError: Mismatch in sequence length for field: tensor. Expected 5, received 6
"""
def __post_init__(self):
self._check_sequence_lengths_match()
@abstractmethod
def __len__(self):
raise NotImplementedError
def __getitem__(self, idx: int | list[int] | slice | np.ndarray):
updated_fields = {}
if isinstance(idx, int):
# make it so that things remain sequential
idx = [idx]
for fld in fields(self):
if fld.metadata.get("sequence", False):
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
value = getattr(self, fld.name)
if value is None:
continue
match sequence_dim:
case 0:
# sequence is first dimension
value = getattr(self, fld.name)
value = slice_any_object(value, idx)
updated_fields[fld.name] = value
case 1:
new_value = [slice_any_object(item, idx) for item in value]
updated_fields[fld.name] = value.__class__(new_value)
case _:
raise NotImplementedError(
"Arbitrary slicing for different sequence length fields is not implemented"
)
return replace(self, **updated_fields)
def _check_sequence_lengths_match(self):
"""Checks if sequence lengths of all "sequence" fields match."""
for fld in fields(self):
if fld.metadata.get("sequence", False) and fld.name != "complex":
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
value = getattr(self, fld.name)
if value is None:
continue
match sequence_dim:
case 0:
# sequence is first dimension
value = getattr(self, fld.name)
if len(value) != len(self):
raise ValueError(
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(value)}"
)
case 1:
for item in value:
if len(item) != len(self):
raise ValueError(
f"Mismatch in sequence length for field: {fld.name}. Expected {len(self)}, received {len(item)}"
)
case _:
raise NotImplementedError(
"Arbitrary matching for different sequence length fields is not implemented"
)
@classmethod
def concat(cls, items: list[T], **kwargs) -> T:
updated_fields = {}
for fld in fields(cls):
if fld.metadata.get("sequence", False):
# this is a sequence, should be the same length as all other sequences
sequence_dim = fld.metadata.get("sequence_dim", 0)
join_value = fld.metadata.get("join_token", None)
if getattr(items[0], fld.name) is None:
continue
values = [getattr(item, fld.name) for item in items]
match sequence_dim:
case 0:
# sequence is first dimension
value = concat_objects(values, join_value)
updated_fields[fld.name] = value
case 1:
new_value = [
concat_objects(item, join_value) for item in zip(*values)
]
updated_fields[fld.name] = getattr(
items[0], fld.name
).__class__(new_value)
case _:
raise NotImplementedError(
"Arbitrary joining for different sequence length fields is not implemented"
)
updated_fields.update(kwargs)
return replace(
items[0], # type: ignore
**updated_fields,
)
|