from __future__ import annotations import os from collections import defaultdict from contextlib import nullcontext from dataclasses import is_dataclass from io import BytesIO from typing import ( Any, ContextManager, Generator, Iterable, Protocol, Sequence, TypeVar, runtime_checkable, ) from warnings import warn import huggingface_hub import numpy as np import torch import zstd from .esmfold2_constants_esm3 import CHAIN_BREAK_STR from .esmfold2_utils_types import FunctionAnnotation MAX_SUPPORTED_DISTANCE = 1e6 TSequence = TypeVar("TSequence", bound=Sequence) @runtime_checkable class Concatable(Protocol): @classmethod def concat(cls, objs: list[Concatable]) -> Concatable: ... def slice_python_object_as_numpy( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: """ Slice a python object (like a list, string, or tuple) as if it was a numpy object. Example: >>> obj = "ABCDE" >>> slice_python_object_as_numpy(obj, [1, 3, 4]) "BDE" >>> obj = [1, 2, 3, 4, 5] >>> slice_python_object_as_numpy(obj, np.arange(5) < 3) [1, 2, 3] """ if np.isscalar(idx): idx = [int(idx)] # type: ignore if isinstance(idx, np.ndarray) and idx.dtype == bool: sliced_obj = [obj[i] for i in np.where(idx)[0]] elif isinstance(idx, slice): sliced_obj = obj[idx] else: sliced_obj = [obj[i] for i in idx] # type: ignore match obj, sliced_obj: case str(), list(): sliced_obj = "".join(sliced_obj) case _: sliced_obj = obj.__class__(sliced_obj) # type: ignore return sliced_obj # type: ignore def slice_any_object( obj: TSequence, idx: int | list[int] | slice | np.ndarray ) -> TSequence: """ Slice a arbitrary object (like a list, string, or tuple) as if it was a numpy object. Similar to `slice_python_object_as_numpy`, but detects if it's a numpy array or Tensor and uses the existing slice method if so. If the object is a dataclass, it will simply apply the index to the object, under the assumption that the object has correcty implemented numpy indexing. Example: >>> obj = "ABCDE" >>> slice_any_object(obj, [1, 3, 4]) "BDE" >>> obj = np.array([1, 2, 3, 4, 5]) >>> slice_any_object(obj, np.arange(5) < 3) np.array([1, 2, 3]) >>> obj = ProteinChain.from_rcsb("1a3a", "A") >>> slice_any_object(obj, np.arange(len(obj)) < 10) # ProteinChain w/ length 10 """ if isinstance(obj, (np.ndarray, torch.Tensor)): return obj[idx] # type: ignore elif is_dataclass(obj): # if passing a dataclass, assume it implements a custom slice return obj[idx] # type: ignore else: return slice_python_object_as_numpy(obj, idx) def rbf(values, v_min, v_max, n_bins=16): """ Returns RBF encodings in a new dimension at the end. """ rbf_centers = torch.linspace( v_min, v_max, n_bins, device=values.device, dtype=values.dtype ) rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1]) rbf_std = (v_max - v_min) / n_bins z = (values.unsqueeze(-1) - rbf_centers) / rbf_std return torch.exp(-(z**2)) def batched_gather(data, inds, dim=0, no_batch_dims=0): ranges = [] for i, s in enumerate(data.shape[:no_batch_dims]): r = torch.arange(s) r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) ranges.append(r) remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds ranges.extend(remaining_dims) return data[ranges] def node_gather(s: torch.Tensor, edges: torch.Tensor) -> torch.Tensor: return batched_gather(s.unsqueeze(-3), edges, -2, no_batch_dims=len(s.shape) - 1) def knn_graph( coords: torch.Tensor, coord_mask: torch.Tensor, padding_mask: torch.Tensor, sequence_id: torch.Tensor, *, no_knn: int, ): L = coords.shape[-2] num_by_dist = min(no_knn, L) device = coords.device coords = coords.nan_to_num() coord_mask = ~(coord_mask[..., None, :] & coord_mask[..., :, None]) padding_pairwise_mask = padding_mask[..., None, :] | padding_mask[..., :, None] if sequence_id is not None: padding_pairwise_mask |= torch.unsqueeze(sequence_id, 1) != torch.unsqueeze( sequence_id, 2 ) dists = (coords.unsqueeze(-2) - coords.unsqueeze(-3)).norm(dim=-1) arange = torch.arange(L, device=device) seq_dists = (arange.unsqueeze(-1) - arange.unsqueeze(-2)).abs() # We only support up to a certain distance, above that, we use sequence distance # instead. This is so that when a large portion of the structure is masked out, # the edges are built according to sequence distance. max_dist = MAX_SUPPORTED_DISTANCE if not (dists[~coord_mask] < max_dist).all(): raise ValueError( f"Coordinate pairwise distances exceed max supported distance ({max_dist}). " ) struct_then_seq_dist = ( seq_dists.to(dists.dtype) .mul(1e2) .add(max_dist) .where(coord_mask, dists) .masked_fill(padding_pairwise_mask, torch.inf) ) dists, edges = struct_then_seq_dist.sort(dim=-1, descending=False) # This is a L x L tensor, where we index by rows first, # and columns are the edges we should pick. chosen_edges = edges[..., :num_by_dist] chosen_mask = dists[..., :num_by_dist].isfinite() return chosen_edges, chosen_mask def stack_variable_length_tensors( sequences: Sequence[torch.Tensor], constant_value: int | float = 0, dtype: torch.dtype | None = None, ) -> torch.Tensor: """Automatically stack tensors together, padding variable lengths with the value in constant_value. Handles an arbitrary number of dimensions. Examples: >>> tensor1, tensor2 = torch.ones([2]), torch.ones([5]) >>> stack_variable_length_tensors(tensor1, tensor2) tensor of shape [2, 5]. First row is [1, 1, 0, 0, 0]. Second row is all ones. >>> tensor1, tensor2 = torch.ones([2, 4]), torch.ones([5, 3]) >>> stack_variable_length_tensors(tensor1, tensor2) tensor of shape [2, 5, 4] """ batch_size = len(sequences) shape = [batch_size] + np.max([seq.shape for seq in sequences], 0).tolist() if dtype is None: dtype = sequences[0].dtype device = sequences[0].device array = torch.full(shape, constant_value, dtype=dtype, device=device) for arr, seq in zip(array, sequences): arrslice = tuple(slice(dim) for dim in seq.shape) arr[arrslice] = seq return array def binpack( tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float ): """ Args: tensor (Tensor): [B, L, ...] Returns: Tensor: [B_binpacked, L_binpacked, ...] """ if sequence_id is None: return tensor num_sequences = sequence_id.max(dim=-1).values + 1 dims = sequence_id.shape + tensor.shape[2:] output_tensor = torch.full( dims, fill_value=pad_value, dtype=tensor.dtype, device=tensor.device ) idx = 0 for batch_idx, (batch_seqid, batch_num_sequences) in enumerate( zip(sequence_id, num_sequences) ): for seqid in range(batch_num_sequences): mask = batch_seqid == seqid output_tensor[batch_idx, mask] = tensor[idx, : mask.sum()] idx += 1 return output_tensor def unbinpack( tensor: torch.Tensor, sequence_id: torch.Tensor | None, pad_value: int | float ): """ Args: tensor (Tensor): [B, L, ...] Returns: Tensor: [B_unbinpacked, L_unbinpack, ...] """ if sequence_id is None: return tensor unpacked_tensors = [] num_sequences = sequence_id.max(dim=-1).values + 1 for batch_idx, (batch_seqid, batch_num_sequences) in enumerate( zip(sequence_id, num_sequences) ): for seqid in range(batch_num_sequences): mask = batch_seqid == seqid unpacked = tensor[batch_idx, mask] unpacked_tensors.append(unpacked) return stack_variable_length_tensors(unpacked_tensors, pad_value) def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore """ Returns an autocast context manager that disables downcasting by AMP. Args: device_type: The device type ('cpu' or 'cuda') Returns: An autocast context manager with the specified behavior. """ if device_type == "cpu": return torch.amp.autocast(device_type, enabled=False) # type: ignore elif device_type == "mps": # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast. return nullcontext() elif device_type == "cuda": return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore else: raise ValueError(f"Unsupported device type: {device_type}") def merge_ranges(ranges: list[range], merge_gap_max: int | None = None) -> list[range]: """Merge overlapping ranges into sorted, non-overlapping segments. Args: ranges: collection of ranges to merge. merge_gap_max: optionally merge neighboring ranges that are separated by a gap no larger than this size. Returns: non-overlapping ranges merged from the inputs, sorted by position. """ ranges = sorted(ranges, key=lambda r: r.start) merge_gap_max = merge_gap_max if merge_gap_max is not None else 0 assert merge_gap_max >= 0, f"Invalid merge_gap_max: {merge_gap_max}" merged = [] for r in ranges: if not merged: merged.append(r) else: last = merged[-1] if last.stop + merge_gap_max >= r.start: merged[-1] = range(last.start, max(last.stop, r.stop)) else: merged.append(r) return merged def merge_annotations( annotations: list[FunctionAnnotation], merge_gap_max: int | None = None ) -> list[FunctionAnnotation]: """Merges annotations into non-overlapping segments. Args: annotations: annotations to merge. merge_gap_max: optionally merge neighboring ranges that are separated by a gap no larger than this size. Returns: non-overlapping annotations with gaps merged. """ grouped: dict[str, list[range]] = defaultdict(list) for a in annotations: # +1 since FunctionAnnotation.end is inlcusive. grouped[a.label].append(range(a.start, a.end + 1)) merged = [] for label, ranges in grouped.items(): merged_ranges = merge_ranges(ranges, merge_gap_max=merge_gap_max) for range_ in merged_ranges: annotation = FunctionAnnotation( label=label, start=range_.start, end=range_.stop - 1, # convert range.stop exclusive -> inclusive. ) merged.append(annotation) return merged def replace_inf(data): if data is None: return None array = np.asarray(data, dtype=np.float32) array = np.where(np.isinf(array), 1000, array) return array.tolist() def maybe_tensor(x, convert_none_to_nan: bool = False) -> torch.Tensor | None: if x is None: return None if isinstance(x, torch.Tensor): return x if isinstance(x, list) and all(isinstance(t, torch.Tensor) for t in x): return torch.stack(x) if convert_none_to_nan: x = np.asarray(x, dtype=np.float32) x = np.where(x is None, np.nan, x) return torch.tensor(x) def maybe_list(x, convert_nan_to_none: bool = False) -> list | None: if x is None: return None if not convert_nan_to_none: return x.tolist() # Handle both torch.tensor and np.ndarray input. if isinstance(x, torch.Tensor): nan_mask = torch.isnan(x).cpu().numpy() np_arr = x.cpu().numpy().astype(object) elif isinstance(x, np.ndarray): nan_mask = np.isnan(x) np_arr = x.astype(object) else: raise TypeError("maybe_list can only work with torch.tensor or np.ndarray.") np_arr[nan_mask] = None return np_arr.tolist() def huggingfacehub_login(): """Authenticates with the Hugging Face Hub using the HF_TOKEN environment variable, else by prompting the user""" token = os.environ.get("HF_TOKEN") huggingface_hub.login(token=token) def get_chainbreak_boundaries_from_sequence(sequence: Sequence[str]) -> np.ndarray: chain_boundaries = [0] for i, aa in enumerate(sequence): if aa == CHAIN_BREAK_STR: if i == (len(sequence) - 1): raise ValueError( "Encountered chain break token at end of sequence, this is unexpected." ) if i == (len(sequence) - 2): warn( "Encountered chain break token at penultimate position, this is unexpected." ) chain_boundaries.append(i) chain_boundaries.append(i + 1) chain_boundaries.append(len(sequence)) assert len(chain_boundaries) % 2 == 0 chain_boundaries = np.array(chain_boundaries).reshape(-1, 2) return chain_boundaries def deserialize_tensors(b: bytes) -> Any: buf = BytesIO(zstd.ZSTD_uncompress(b)) d = torch.load(buf, map_location="cpu", weights_only=False) return d def join_lists( lists: Sequence[Sequence[Any]], separator: Sequence[Any] | None = None ) -> list[Any]: """Joins multiple lists with separator element. Like str.join but for lists. Example: [[1, 2], [3], [4]], separator=[0] -> [1, 2, 0, 3, 0, 4] Args: lists: Lists of elements to chain separator: separators to intsert between chained output. Returns: Joined lists. """ if not lists: return [] joined = [] joined.extend(lists[0]) for l in lists[1:]: if separator: joined.extend(separator) joined.extend(l) return joined def iterate_with_intermediate( lists: Iterable, intermediate ) -> Generator[Any, None, None]: """ Iterate over the iterable, yielding the intermediate value between every element of the intermediate. Useful for joining objects with separator tokens. """ it = iter(lists) yield next(it) for l in it: yield intermediate yield l def concat_objects(objs: Sequence[Any], separator: Any | None = None): """ Concat objects with each other using a separator token. Supports: - Concatable (objects that implement `concat` classmethod) - strings - lists - numpy arrays - torch Tensors Example: >>> foo = "abc" >>> bar = "def" >>> concat_objects([foo, bar], "|") "abc|def" """ match objs[0]: case Concatable(): return objs[0].__class__.concat(objs) # type: ignore case str(): assert isinstance( separator, str ), "Trying to join strings but separator is not a string" return separator.join(objs) case list(): if separator is not None: return join_lists(objs, [separator]) else: return join_lists(objs) case np.ndarray(): if separator is not None: return np.concatenate( list(iterate_with_intermediate(objs, np.array([separator]))) ) else: return np.concatenate(objs) case torch.Tensor(): if separator is not None: return torch.cat( list(iterate_with_intermediate(objs, torch.tensor([separator]))) ) else: return torch.cat(objs) # type: ignore case _: raise TypeError(type(objs[0]))