| import torch |
| from typing import Optional, Tuple |
|
|
|
|
# Alphabets of single-character symbols accepted for each sequence type.
# Protein residues: the 20 canonical amino acids, then the noncanonical /
# ambiguous symbols X, B, U, O, Z and '*' (the translation of the stop codons
# TAA, TAG and TGA in DNA_CODON_TO_AA).
AA_SET = set('LAGVSERTIPDKQNFYMHWC' 'XBUOZ*')
# Codon tokens: one character per codon; the same characters are the keys of
# CODON_TO_AA.
CODON_SET = set(
    'aA@bB#$%rRnNdDcCeEqQ^G&ghHiIj+MmlJLkK(fFpPoO'
    '=szZwSXTtxWyYuvUV]})'
)
DNA_SET = set('ATCG')  # DNA nucleotides
RNA_SET = set('AUCG')  # RNA nucleotides (U in place of T)
# Members of AA_SET that are not one of the 20 canonical amino acids.
NONCANONICAL_AMINO_ACIDS = set('XBUOZ*')
# One DNA codon per canonical amino acid, keyed by one-letter code. Each codon
# translates back to its key under DNA_CODON_TO_AA. The name suggests these are
# the codons preferred in human genes -- NOTE(review): confirm the source table.
AMINO_ACID_TO_HUMAN_CODON = dict(
    A='GCC', R='CGC', N='AAC', D='GAC', C='TGC',
    Q='CAG', E='GAG', G='GGC', H='CAC', I='ATC',
    L='CTG', K='AAG', M='ATG', F='TTC', P='CCC',
    S='AGC', T='ACC', W='TGG', Y='TAC', V='GTG',
)
# An alanine codon (GCT -> 'A') distinct from the GCC chosen above. The name
# suggests it stands in for noncanonical residues -- confirm with callers.
NONCANONICAL_ALANINE_CODON = 'GCT'
|
|
# Representative codon token for each canonical amino acid. Most residues
# reuse their own letter as the token. There are two exceptions:
#   - Arg uses 'B'.
#   - Met uses '(' because 'M' is a leucine token in CODON_TO_AA.
# Overriding existing keys keeps their original position in the dict.
AA_TO_CODON_TOKEN = {
    **{residue: residue for residue in 'ARNDCQEGHILKMFPSTWYV'},
    'R': 'B',
    'M': '(',
}
# Decoding table for the single-character codon vocabulary. Each of the 64
# tokens stands for one codon and maps to the amino acid it encodes ('*' for
# the three stop codons). Tokens are grouped by amino acid. Each group has as
# many tokens as that residue has codons in DNA_CODON_TO_AA.
CODON_TO_AA = {
    'a': 'A', 'A': 'A', '@': 'A', 'b': 'A',                        # Ala (4)
    'B': 'R', '#': 'R', '$': 'R', '%': 'R', 'r': 'R', 'R': 'R',    # Arg (6)
    'n': 'N', 'N': 'N',                                            # Asn (2)
    'd': 'D', 'D': 'D',                                            # Asp (2)
    'c': 'C', 'C': 'C',                                            # Cys (2)
    'e': 'E', 'E': 'E',                                            # Glu (2)
    'q': 'Q', 'Q': 'Q',                                            # Gln (2)
    '^': 'G', 'G': 'G', '&': 'G', 'g': 'G',                        # Gly (4)
    'h': 'H', 'H': 'H',                                            # His (2)
    'i': 'I', 'I': 'I', 'j': 'I',                                  # Ile (3)
    '+': 'L', 'M': 'L', 'm': 'L', 'l': 'L', 'J': 'L', 'L': 'L',    # Leu (6)
    'k': 'K', 'K': 'K',                                            # Lys (2)
    '(': 'M',                                                      # Met (1)
    'f': 'F', 'F': 'F',                                            # Phe (2)
    'p': 'P', 'P': 'P', 'o': 'P', 'O': 'P',                        # Pro (4)
    '=': 'S', 's': 'S', 'z': 'S', 'Z': 'S', 'w': 'S', 'S': 'S',    # Ser (6)
    # 'X' belongs to the threonine group. 'W' is the single tryptophan token;
    # AA_TO_CODON_TOKEN encodes Trp as 'W'. Mapping 'X' to Ser and 'W' to Thr
    # would give Ser 7 tokens and Trp none, and decode every Trp as Thr.
    'X': 'T', 'T': 'T', 't': 'T', 'x': 'T',                        # Thr (4)
    'W': 'W',                                                      # Trp (1)
    'y': 'Y', 'Y': 'Y',                                            # Tyr (2)
    'u': 'V', 'v': 'V', 'U': 'V', 'V': 'V',                        # Val (4)
    ']': '*', '}': '*', ')': '*',                                  # stop (3)
}
# Standard genetic code (NCBI translation table 1) for DNA codons; '*' marks
# the stop codons. Codons are enumerated with the bases in T, C, A, G order
# at every position (TTT, TTC, TTA, TTG, TCT, ...). The string lists the
# one-letter translation of each codon in exactly that order.
DNA_CODON_TO_AA = dict(
    zip(
        (
            first + second + third
            for first in 'TCAG'
            for second in 'TCAG'
            for third in 'TCAG'
        ),
        'FFLLSSSSYY**CC*WLLLLPPPPHHQQRRRRIIIMTTTTNNKKSSRRVVVVAAAADDEEGGGG',
    )
)

# The same table for RNA codons: every thymine (T) becomes uracil (U).
RNA_CODON_TO_AA = dict(
    zip(
        (dna_codon.replace('T', 'U') for dna_codon in DNA_CODON_TO_AA),
        DNA_CODON_TO_AA.values(),
    )
)
|
|
|
|
|
|
def pad_and_concatenate_dimer(
    A: torch.Tensor,
    B: torch.Tensor,
    a_mask: Optional[torch.Tensor] = None,
    b_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Concatenate the valid positions of two batched sequences and pad them.

    For each batch row ``i``, the first ``a_mask[i].sum()`` positions of
    ``A[i]`` are followed immediately by the first ``b_mask[i].sum()``
    positions of ``B[i]``. Rows are zero-padded on the right to the longest
    combined length in the batch.

    Args:
        A: Tensor of shape ``(batch, L_a, d)``.
        B: Tensor of shape ``(batch, L_b, d)``. ``L_b`` may differ from
            ``L_a``, but ``batch`` and ``d`` must match ``A``.
        a_mask: Optional ``(batch, L_a)`` mask for ``A``; defaults to all ones.
        b_mask: Optional ``(batch, L_b)`` mask for ``B``; defaults to all ones.

    Returns:
        ``(combined, combined_mask)``:

        - ``combined`` has shape ``(batch, max_len, d)`` and ``A``'s dtype.
        - ``combined_mask`` is a float ``(batch, max_len)`` tensor. It holds
          ones over the filled prefix of each row and zeros elsewhere.
        - An empty batch yields ``max_len == 0``.

    Note:
        Only the number of nonzero mask entries is used. Masks are therefore
        assumed to be contiguous prefixes: valid positions first, padding last.
    """
    batch_size, len_a, d = A.size()
    device = A.device
    if a_mask is None:
        a_mask = torch.ones(batch_size, len_a, device=device)
    if b_mask is None:
        # Size the default from B, not A: the two chains may differ in length.
        b_mask = torch.ones(batch_size, B.size(1), device=device)

    # Reduce each mask once (one host sync each) instead of per row and per loop.
    a_lens = [int(n) for n in a_mask.sum(dim=1).tolist()]
    b_lens = [int(n) for n in b_mask.sum(dim=1).tolist()]
    # default=0 keeps an empty batch from raising on max() of nothing.
    max_len = max((la + lb for la, lb in zip(a_lens, b_lens)), default=0)

    # Keep A's dtype so half/bfloat16/float64 inputs are not silently recast.
    combined = torch.zeros(batch_size, max_len, d, dtype=A.dtype, device=device)
    combined_mask = torch.zeros(batch_size, max_len, device=device)
    for i, (a_len, b_len) in enumerate(zip(a_lens, b_lens)):
        combined[i, :a_len] = A[i, :a_len]
        combined[i, a_len:a_len + b_len] = B[i, :b_len]
        combined_mask[i, :a_len + b_len] = 1
    return combined, combined_mask
|
|