Instructions to use Taykhoom/CodonBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/CodonBERT with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/CodonBERT", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Taykhoom/CodonBERT", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 3,707 Bytes
6509a75 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | import warnings
import numpy as np
from transformers import BertTokenizer
class CodonBertTokenizer(BertTokenizer):
    """BertTokenizer that auto-converts nucleotide sequences to codon-level tokens.

    Raw nucleotide input is normalized (whitespace stripped, uppercased,
    T->U), then split into non-overlapping 3-mer codons before vocab lookup.
    Trailing 1-2 nucleotides that do not form a complete codon are dropped.

    eos_token defaults to the sep_token string ("[SEP]") unless the caller
    supplies one, so that pooling code that excludes both CLS and EOS/SEP
    positions works correctly.

    Standard usage (raw nucleotides):
        tokenizer("AUGAAAGGG")
        tokenizer(["AUGAAAGGG", "AUGUUUCCC"], return_tensors="pt", padding=True)

    CDS-aware usage (full mRNA + CDS track -> extract CDS, chunk, encode):
        enc, chunk_counts = tokenizer.batch_encode_with_cds(
            ["NNNATGAAAGGGNN"],
            cds_tracks=[np.array([0,0,0,1,0,0,1,0,0,1,0,0,0,0])],
            return_tensors="pt",
            padding=True,
        )

    Works with compare_minimal_vs_mm.py --use_cds out of the box.
    """

    def __init__(self, *args, **kwargs):
        # Only fill in eos_token when the caller did not pass one explicitly.
        kwargs.setdefault("eos_token", "[SEP]")
        super().__init__(*args, **kwargs)

    def _tokenize(self, text, split_special_tokens=False):
        """Split ``text`` into codon tokens.

        All whitespace is removed, letters are uppercased and T is mapped to
        U; the result is cut into consecutive 3-mers and any 1-2 nucleotide
        remainder is discarded. ``split_special_tokens`` is accepted but
        unused here.
        """
        seq = "".join(text.split()).upper().replace("T", "U")
        n = len(seq) - len(seq) % 3
        return [seq[i:i + 3] for i in range(0, n, 3)]

    @staticmethod
    def _extract_cds(sequence, cds):
        """Return the CDS slice of ``sequence`` as marked by ``cds``.

        Args:
            sequence: Nucleotide string.
            cds: Array-like per-nucleotide track indexed like ``sequence``;
                non-zero entries mark the first nucleotide of each codon in
                the CDS. Plain lists are accepted.

        Returns:
            The substring from the first marked position through the last
            nucleotide of the last marked codon (last mark + 2), truncated to
            a multiple of 3 if it is not one already. If nothing is marked,
            the whole sequence truncated to a multiple of 3 is returned.
            A warning is emitted in both fallback cases.
        """
        # One "non-zero" criterion drives both the empty check and the
        # boundary search, so tracks whose marks are not exactly 1 (or whose
        # signed values happen to sum to zero) are handled consistently.
        marks = np.flatnonzero(np.asarray(cds))
        if marks.size == 0:
            warnings.warn("No CDS found. Returning truncated sequence.")
            n = len(sequence) - len(sequence) % 3
            return sequence[:n]
        first = int(marks[0])
        last = int(marks[-1]) + 2  # final nucleotide of the last marked codon
        proposed = sequence[first:last + 1]
        remainder = len(proposed) % 3
        if remainder:
            warnings.warn("Irregular CDS. Returning truncated sequence.")
            return proposed[:-remainder]
        return proposed

    def batch_encode_with_cds(self, sequences, cds_tracks, max_length=None, **kwargs):
        """Encode a batch of raw mRNA sequences using CDS-aware preprocessing.

        Args:
            sequences: List of raw nucleotide strings.
            cds_tracks: One array-like track per sequence. Non-zero values
                mark the first nucleotide of each codon in the CDS region.
            max_length: Max content codon-tokens per chunk (special tokens NOT
                counted). Defaults to model_max_length - 2. This matches the
                convention in compare_minimal_vs_mm.py where max_length is
                already adjusted for special tokens. It is consumed here and
                is not forwarded to batch_encode_plus.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors, padding).

        Returns:
            (BatchEncoding, chunk_counts): chunk_counts[i] is the number of
            chunks produced from sequence i (at least 1; a sequence with no
            usable codons contributes one empty chunk).

        Raises:
            ValueError: If the number of sequences and CDS tracks differ, or
                if the per-chunk codon budget is smaller than 1.
        """
        sequences = list(sequences)
        cds_tracks = list(cds_tracks)
        # zip() would silently drop the tail and misalign chunk_counts.
        if len(sequences) != len(cds_tracks):
            raise ValueError(
                f"Got {len(sequences)} sequences but {len(cds_tracks)} CDS tracks; "
                "expected exactly one CDS track per sequence."
            )

        budget_codons = (self.model_max_length - 2) if max_length is None else max_length
        if budget_codons < 1:
            raise ValueError(
                f"Chunk budget must be at least one codon, got {budget_codons}."
            )
        budget_nt = budget_codons * 3

        all_strings = []
        chunk_counts = []
        for seq, cds in zip(sequences, cds_tracks):
            seq = seq.replace("T", "U").replace("t", "u").upper()
            cds_seq = self._extract_cds(seq, cds)
            # cds_seq length and budget_nt are both multiples of 3, so every
            # chunk already holds whole codons.
            chunks = [cds_seq[i:i + budget_nt] for i in range(0, len(cds_seq), budget_nt)]
            if not chunks:
                # Keep one (empty) row per sequence so outputs stay aligned.
                chunks = [""]
            all_strings.extend(chunks)
            chunk_counts.append(len(chunks))
        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts
|