| from types import SimpleNamespace |
|
|
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset |
|
|
|
|
| ModalityType = SimpleNamespace( |
| AA="aa", |
| DNA="dna", |
| PDB="pdb", |
| GO="go", |
| MSA="msa", |
| TEXT="text", |
| ) |
|
|
| class Normalize(nn.Module): |
| def __init__(self, dim: int) -> None: |
| super().__init__() |
| self.dim = dim |
|
|
| def forward(self, x): |
| return torch.nn.functional.normalize(x, dim=self.dim, p=2) |
|
|
| class EmbeddingDataset(Dataset): |
| """ |
| The main class for turning any modality to a torch Dataset that can be passed to |
| a torch dataloader. Any modality that doesn't fit into the __getitem__ |
| method can subclass this and modify the __getitem__ method. |
| """ |
| def __init__(self, sequence_file_path, embeddings_file_path, modality): |
| self.sequence = pd.read_csv(sequence_file_path) |
| self.embedding = torch.load(embeddings_file_path) |
| self.modality = modality |
|
|
| def __len__(self): |
| return len(self.sequence) |
|
|
| def __getitem__(self, idx): |
| sequence = self.sequence.iloc[idx, 0] |
| embedding = self.embedding[idx] |
| return {"aa": sequence, self.modality: embedding} |
|
|
| class DualEmbeddingDataset(Dataset): |
| """ |
| The main class for turning any modality to a torch Dataset that can be passed to |
| a torch dataloader. Any modality that doesn't fit into the __getitem__ |
| method can subclass this and modify the __getitem__ method. |
| """ |
| def __init__(self, sequence_embeddings_file_path, embeddings_file_path, modality): |
| self.sequence_embedding = torch.load(sequence_embeddings_file_path) |
| self.embedding = torch.load(embeddings_file_path) |
| self.modality = modality |
|
|
| def __len__(self): |
| return len(self.sequence_embedding) |
|
|
| def __getitem__(self, idx): |
| sequence_embedding = self.sequence_embedding[idx] |
| embedding = self.embedding[idx] |
| return {"aa": sequence_embedding, self.modality: embedding} |
| |
| class ProteinBindModel(nn.Module): |
|
|
| def __init__( |
| self, |
| aa_embed_dim, |
| dna_embed_dim, |
| pdb_embed_dim, |
| go_embed_dim, |
| msa_embed_dim, |
| text_embed_dim, |
| in_embed_dim, |
| out_embed_dim |
| ): |
| super().__init__() |
| self.modality_trunks = self._create_modality_trunk( |
| aa_embed_dim, |
| dna_embed_dim, |
| pdb_embed_dim, |
| go_embed_dim, |
| msa_embed_dim, |
| text_embed_dim, |
| out_embed_dim |
| ) |
| self.modality_heads = self._create_modality_head( |
| in_embed_dim, |
| out_embed_dim, |
| ) |
| self.modality_postprocessors = self._create_modality_postprocessors( |
| out_embed_dim |
| ) |
|
|
|
|
| def _create_modality_trunk( |
| self, |
| aa_embed_dim, |
| dna_embed_dim, |
| pdb_embed_dim, |
| go_embed_dim, |
| msa_embed_dim, |
| text_embed_dim, |
| in_embed_dim |
| ): |
| """ |
| The current layers are just a proof of concept |
| and are subject to the opinion of others. |
| :param aa_embed_dim: |
| :param dna_embed_dim: |
| :param pdb_embed_dim: |
| :param go_embed_dim: |
| :param msa_embed_dim: |
| :param text_embed_dim: |
| :param in_embed_dim: |
| :return: |
| """ |
| modality_trunks = {} |
|
|
| modality_trunks[ModalityType.AA] = nn.Sequential( |
| nn.Linear(aa_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
|
|
| modality_trunks[ModalityType.DNA] = nn.Sequential( |
| nn.Linear(dna_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
|
|
| modality_trunks[ModalityType.PDB] = nn.Sequential( |
| nn.Linear(pdb_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
| |
| modality_trunks[ModalityType.GO] = nn.Sequential( |
| nn.Linear(go_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
|
|
| modality_trunks[ModalityType.MSA] = nn.Sequential( |
| nn.Linear(msa_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
|
|
| modality_trunks[ModalityType.TEXT] = nn.Sequential( |
| nn.Linear(text_embed_dim, 512), |
| nn.ReLU(), |
| nn.Linear(512, 512), |
| nn.ReLU(), |
| nn.Linear(512, in_embed_dim), |
| ) |
|
|
| return nn.ModuleDict(modality_trunks) |
|
|
| def _create_modality_head( |
| self, |
| in_embed_dim, |
| out_embed_dim |
| ): |
| modality_heads = {} |
|
|
| modality_heads[ModalityType.AA] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
|
|
| modality_heads[ModalityType.DNA] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
|
|
| modality_heads[ModalityType.PDB] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
|
|
| modality_heads[ModalityType.GO] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
|
|
| modality_heads[ModalityType.MSA] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
|
|
| modality_heads[ModalityType.TEXT] = nn.Sequential( |
| nn.LayerNorm(normalized_shape=in_embed_dim, eps=1e-6), |
| nn.Dropout(p=0.5), |
| nn.Linear(in_embed_dim, out_embed_dim, bias=False), |
| ) |
| return nn.ModuleDict(modality_heads) |
|
|
| def _create_modality_postprocessors(self, out_embed_dim): |
| modality_postprocessors = {} |
| modality_postprocessors[ModalityType.AA] = Normalize(dim=-1) |
| modality_postprocessors[ModalityType.DNA] = Normalize(dim=-1) |
| modality_postprocessors[ModalityType.PDB] = Normalize(dim=-1) |
| modality_postprocessors[ModalityType.TEXT] = Normalize(dim=-1) |
| modality_postprocessors[ModalityType.GO] = Normalize(dim=-1) |
| modality_postprocessors[ModalityType.MSA] = Normalize(dim=-1) |
|
|
|
|
| return nn.ModuleDict(modality_postprocessors) |
|
|
| def forward(self, inputs): |
| """ |
| input = {k_1: [v],k_n: [v]} |
| for key in input |
| get trunk for key |
| forward pass of value in trunk |
| get projection head of key |
| forward pass of value in projection head |
| append output in output dict |
| return { k_1, [o], k_n: [o]} |
| """ |
|
|
| outputs = {} |
|
|
| for modality_key, modality_value in inputs.items(): |
|
|
|
|
| modality_value = self.modality_trunks[modality_key]( |
| modality_value |
| ) |
|
|
| modality_value = self.modality_heads[modality_key]( |
| modality_value |
| ) |
| |
| modality_value = self.modality_postprocessors[modality_key]( |
| modality_value |
| ) |
| outputs[modality_key] = modality_value |
|
|
| return outputs |
|
|
|
|
| def create_proteinbind(pretrained=False): |
| """ |
| The embedding dimensions here are dummy |
| :param pretrained: |
| :return: |
| """ |
| model = ProteinBindModel( |
| aa_embed_dim=480, |
| dna_embed_dim=1280, |
| pdb_embed_dim=128, |
| go_embed_dim=600, |
| msa_embed_dim=768, |
| text_embed_dim=768, |
| in_embed_dim=1024, |
| out_embed_dim=1024 |
| ) |
|
|
| if pretrained: |
| |
| PATH = 'best_model.pth' |
|
|
| model.load_state_dict(torch.load(PATH)) |
|
|
| return model |
|
|