"""Protein dataset class."""

import os
import pickle
from pathlib import Path
from glob import glob
from typing import Optional, Sequence, List, Union
from functools import lru_cache

import tree
import numpy as np
import pandas as pd
import torch

from src.common import residue_constants, data_transforms, rigid_utils, protein


CA_IDX = residue_constants.atom_order['CA']

# Canonical dtypes enforced when features are converted to tensors.
DTYPE_MAPPING = {
    'aatype': torch.long,
    'atom_positions': torch.double,
    'atom_mask': torch.double,
}

class ProteinFeatureTransform:
    """Preprocess a raw chain-feature dict: patch auxiliary masks, strip
    unmodeled termini, optionally crop to a fixed length, recenter and
    rescale coordinates, convert to tensors, and derive frame and torsion
    features."""

    def __init__(self,
                 unit: Optional[str] = 'angstrom',
                 truncate_length: Optional[int] = None,
                 strip_missing_residues: bool = True,
                 recenter_and_scale: bool = True,
                 eps: float = 1e-8,
                 ):
        if unit == 'angstrom':
            self.coordinate_scale = 1.0
        elif unit in ('nm', 'nanometer'):
            self.coordinate_scale = 0.1
        else:
            raise ValueError(f"Invalid unit: {unit}")

        if truncate_length is not None:
            assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
        self.truncate_length = truncate_length

        self.strip_missing_residues = strip_missing_residues
        self.recenter_and_scale = recenter_and_scale
        self.eps = eps

    def __call__(self, chain_feats):
        # Derive masks and zero-based residue indices.
        chain_feats = self.patch_feats(chain_feats)

        if self.strip_missing_residues:
            chain_feats = self.strip_ends(chain_feats)

        if self.truncate_length is not None:
            chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)

        if self.recenter_and_scale:
            chain_feats = self.recenter_and_scale_coords(
                chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)

        chain_feats = self.map_to_tensors(chain_feats)
        chain_feats = self.protein_data_transform(chain_feats)
        return chain_feats

    @staticmethod
    def patch_feats(chain_feats):
        # A residue counts as observed if its CA atom is resolved.
        seq_mask = chain_feats['atom_mask'][:, CA_IDX]
        # Shift residue indices so they start at zero.
        residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index'])
        patch_feats = {
            'seq_mask': seq_mask,
            'residue_mask': seq_mask,
            'residue_idx': residue_idx,
            'fixed_mask': np.zeros_like(seq_mask),
            'sc_ca_t': np.zeros(seq_mask.shape + (3,)),
        }
        chain_feats.update(patch_feats)
        return chain_feats

    @staticmethod
    def strip_ends(chain_feats):
        # Trim unmodeled termini; aatype 20 is the unknown ('X') token in
        # the residue_constants vocabulary.
        modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
        chain_feats = tree.map_structure(
            lambda x: x[min_idx:(max_idx + 1)], chain_feats)
        return chain_feats

    @staticmethod
    def random_truncate(chain_feats, max_len):
        L = chain_feats['aatype'].shape[0]
        if L > max_len:
            # Crop a random contiguous window of max_len residues.
            start = np.random.randint(0, L - max_len + 1)
            end = start + max_len
            chain_feats = tree.map_structure(
                lambda x: x[start:end], chain_feats)
        return chain_feats

    @staticmethod
    def map_to_tensors(chain_feats):
        chain_feats = {k: torch.as_tensor(v) for k, v in chain_feats.items()}
        for k, dtype in DTYPE_MAPPING.items():
            if k in chain_feats:
                chain_feats[k] = chain_feats[k].type(dtype)
        return chain_feats

    @staticmethod
    def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
        # Recenter on the CA centroid; unresolved atom positions are assumed
        # to be zero-filled, so the unmasked sum equals the masked sum.
        bb_pos = chain_feats['atom_positions'][:, CA_IDX]
        bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
        centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
        scaled_pos = centered_pos * coordinate_scale
        chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
        return chain_feats

    @staticmethod
    def protein_data_transform(chain_feats):
        # Alias keys to the names expected by data_transforms.
        chain_feats.update(
            {
                "all_atom_positions": chain_feats["atom_positions"],
                "all_atom_mask": chain_feats["atom_mask"],
            }
        )
        chain_feats = data_transforms.atom37_to_frames(chain_feats)
        chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
        chain_feats = data_transforms.get_backbone_frames(chain_feats)
        chain_feats = data_transforms.get_chi_angles(chain_feats)
        chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
        chain_feats = data_transforms.make_atom14_masks(chain_feats)
        chain_feats = data_transforms.make_atom14_positions(chain_feats)

        # Drop the temporary aliases.
        chain_feats.pop("all_atom_positions")
        chain_feats.pop("all_atom_mask")
        return chain_feats


class MetadataFilter:
    """Filter a metadata DataFrame by sequence length, chain count,
    resolution, structure method, and oligomeric detail."""

    def __init__(self,
                 min_len: Optional[int] = None,
                 max_len: Optional[int] = None,
                 min_chains: Optional[int] = None,
                 max_chains: Optional[int] = None,
                 min_resolution: Optional[float] = None,
                 max_resolution: Optional[float] = None,
                 include_structure_method: Optional[List[str]] = None,
                 include_oligomeric_detail: Optional[List[str]] = None,
                 **kwargs,
                 ):
        self.min_len = min_len
        self.max_len = max_len
        self.min_chains = min_chains
        self.max_chains = max_chains
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.include_structure_method = include_structure_method
        self.include_oligomeric_detail = include_oligomeric_detail

    def __call__(self, df):
        _pre_filter_len = len(df)
        if self.min_len is not None:
            df = df[df['raw_seq_len'] >= self.min_len]
        if self.max_len is not None:
            df = df[df['raw_seq_len'] <= self.max_len]
        if self.min_chains is not None:
            df = df[df['num_chains'] >= self.min_chains]
        if self.max_chains is not None:
            df = df[df['num_chains'] <= self.max_chains]
        if self.min_resolution is not None:
            df = df[df['resolution'] >= self.min_resolution]
        if self.max_resolution is not None:
            df = df[df['resolution'] <= self.max_resolution]
        # Assumes the metadata CSV names these columns 'structure_method'
        # and 'oligomeric_detail'.
        if self.include_structure_method is not None:
            df = df[df['structure_method'].isin(self.include_structure_method)]
        if self.include_oligomeric_detail is not None:
            df = df[df['oligomeric_detail'].isin(self.include_oligomeric_detail)]

        print(f">>> Metadata filter kept {len(df)} of {_pre_filter_len} samples")
        return df


class RandomAccessProteinDataset(torch.utils.data.Dataset):
    """Random access to pickled protein objects of a dataset.

    Each sample is a dict with keys
    ['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'],
    where each value is an ndarray of shape (L, *), for example:
        'atom_positions': (L, 37, 3)
    """

    def __init__(self,
                 path_to_dataset: Union[Path, str],
                 path_to_seq_embedding: Optional[Path] = None,
                 metadata_filter: Optional[MetadataFilter] = None,
                 training: bool = True,
                 transform: Optional[ProteinFeatureTransform] = None,
                 suffix: Optional[str] = '.pkl',
                 accession_code_filter: Optional[Sequence[str]] = None,
                 **kwargs,
                 ):
        super().__init__()
        path_to_dataset = os.path.expanduser(path_to_dataset)
        suffix = suffix if suffix.startswith('.') else '.' + suffix
        assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"

        if os.path.isfile(path_to_dataset):
            # A metadata CSV listing processed samples.
            assert path_to_dataset.endswith('.csv'), \
                f"Invalid file extension: {path_to_dataset} (has to be .csv)"
            self._df = pd.read_csv(path_to_dataset)
            self._df = self._df.sort_values('modeled_seq_len', ascending=False)
            if metadata_filter:
                self._df = metadata_filter(self._df)
            self._data = self._df['processed_complex_path'].tolist()
        elif os.path.isdir(path_to_dataset):
            # A directory of data files with the given suffix.
            self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
            assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
        else:
            # Otherwise, treat the path as a glob pattern.
            _pattern = path_to_dataset
            self._data = sorted(glob(_pattern))
            assert len(self._data) > 0, f"No files found in '{_pattern}'"

        if accession_code_filter:
            _keep = set(accession_code_filter)
            self._data = [
                p for p in self._data
                if os.path.splitext(os.path.basename(p))[0] in _keep
            ]

        self.data = np.asarray(self._data)
        self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
            if path_to_seq_embedding is not None else None
        self.suffix = suffix
        self.transform = transform
        self.training = training

    @property
    def num_samples(self):
        return len(self.data)

    def len(self):
        return self.__len__()

    def __len__(self):
        return self.num_samples

    def get(self, idx):
        return self.__getitem__(idx)

    @lru_cache(maxsize=100)
    def __getitem__(self, idx):
        """Return a single feature dict for the sample at `idx`.

        Note: lru_cache memoizes per (self, idx), so any stochastic step in
        the transform (e.g. random_truncate) is applied only once per index.
        """
        data_path = self.data[idx]
        accession_code = os.path.splitext(os.path.basename(data_path))[0]

        if self.suffix == '.pkl':
            # Load a pickled feature dict.
            with open(data_path, 'rb') as f:
                data_object = pickle.load(f)
        elif self.suffix == '.pdb':
            # Parse a PDB file into the same dict layout.
            with open(data_path, 'r') as f:
                pdb_string = f.read()
            data_object = protein.from_pdb_string(pdb_string).to_dict()

        if self.transform is not None:
            data_object = self.transform(data_object)

        if self.path_to_seq_embedding is not None:
            # Precomputed per-residue sequence embeddings; index 33 presumably
            # selects the final layer of a 33-layer language model (e.g. ESM).
            embed_dict = torch.load(
                os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
            )
            data_object.update(
                {
                    'seq_emb': embed_dict['representations'][33].float(),
                }
            )

        data_object['accession_code'] = accession_code
        return data_object


class PretrainPDBDataset(RandomAccessProteinDataset):
    def __init__(self,
                 path_to_dataset: str,
                 metadata_filter: MetadataFilter,
                 transform: ProteinFeatureTransform,
                 **kwargs,
                 ):
        super().__init__(path_to_dataset=path_to_dataset,
                         metadata_filter=metadata_filter,
                         transform=transform,
                         **kwargs,
                         )


class SamplingPDBDataset(RandomAccessProteinDataset):
    def __init__(self,
                 path_to_dataset: str,
                 training: bool = False,
                 suffix: str = '.pdb',
                 transform: Optional[ProteinFeatureTransform] = None,
                 accession_code_filter: Optional[Sequence[str]] = None,
                 ):
        assert os.path.isdir(path_to_dataset), \
            f"Invalid path (expected a directory): {path_to_dataset}"
        super().__init__(path_to_dataset=path_to_dataset,
                         training=training,
                         suffix=suffix,
                         transform=transform,
                         accession_code_filter=accession_code_filter,
                         metadata_filter=None,
                         )
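

# A minimal usage sketch of the dataset plus transform pipeline. The glob
# pattern below is an illustrative placeholder, not a repository fixture;
# point it at your own processed .pkl files.
if __name__ == '__main__':
    transform = ProteinFeatureTransform(unit='nm', truncate_length=256)
    dataset = RandomAccessProteinDataset(
        path_to_dataset='data/processed/*.pkl',  # hypothetical path
        transform=transform,
    )
    print(f"Loaded {len(dataset)} samples")
    sample = dataset[0]
    # After the transform, 'atom_positions' is a (L, 37, 3) double tensor.
    print(sample['accession_code'], tuple(sample['atom_positions'].shape))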