Instructions to use Taykhoom/CodonBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Taykhoom/CodonBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("fill-mask", model="Taykhoom/CodonBERT", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("Taykhoom/CodonBERT", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import warnings | |
| import numpy as np | |
| from transformers import BertTokenizer | |
class CodonBertTokenizer(BertTokenizer):
    """BertTokenizer that auto-converts nucleotide sequences to codon-level tokens.

    Raw nucleotide input is normalized (T->U, uppercase, whitespace stripped),
    then split into non-overlapping 3-mer codons before vocab lookup. Trailing
    1-2 nucleotides that do not form a complete codon are dropped.

    eos_token defaults to "[SEP]" (the same string as sep_token) so that pooling
    code that excludes both CLS and EOS/SEP positions works correctly.

    Standard usage (raw nucleotides):
        tokenizer("AUGAAAGGG")
        tokenizer(["AUGAAAGGG", "AUGUUUCCC"], return_tensors="pt", padding=True)

    CDS-aware usage (full mRNA + CDS track -> extract CDS, chunk, encode):
        tokenizer.batch_encode_with_cds(
            ["NNNATGAAAGGGNN"],
            cds_tracks=[np.array([0,0,0,1,0,0,1,0,0,1,0,0,0,0])],
            return_tensors="pt",
            padding=True,
        )

    Works with compare_minimal_vs_mm.py --use_cds out of the box.
    """

    def __init__(self, *args, **kwargs):
        """Build the tokenizer, defaulting ``eos_token`` to ``"[SEP]"``.

        All positional and keyword arguments are forwarded to ``BertTokenizer``.
        """
        # setdefault: an eos_token passed explicitly by the caller still wins.
        kwargs.setdefault("eos_token", "[SEP]")
        super().__init__(*args, **kwargs)

    def _tokenize(self, text, split_special_tokens=False):
        """Split a raw nucleotide string into non-overlapping codon tokens.

        Args:
            text: Nucleotide string; whitespace is removed, letters are
                upper-cased and ``T`` is rewritten to ``U`` before splitting.
            split_special_tokens: Accepted for signature compatibility; unused.

        Returns:
            List of 3-character codon strings. A trailing partial codon
            (1-2 nucleotides) is dropped.
        """
        seq = "".join(text.split()).upper().replace("T", "U")
        usable = len(seq) - len(seq) % 3
        return [seq[i:i + 3] for i in range(0, usable, 3)]

    # Must be a staticmethod: it takes no ``self`` yet is invoked as
    # ``self._extract_cds(seq, cds)`` from batch_encode_with_cds.
    @staticmethod
    def _extract_cds(sequence, cds):
        """Return the coding region of ``sequence``, trimmed to whole codons.

        Args:
            sequence: Nucleotide string.
            cds: Array-like track; non-zero entries mark the first nucleotide
                of each codon in the CDS region.

        Returns:
            The substring from the first marked position through the end of
            the codon that starts at the last marked position. If the track
            has no marks, the whole sequence truncated to a multiple of three
            is returned instead. A warning is emitted in that case and when
            the extracted region is not a multiple of three long.
        """
        # Any non-zero entry counts as a codon start, matching the documented
        # track format (0/1 and boolean tracks behave exactly as before).
        starts = np.flatnonzero(np.asarray(cds))
        if starts.size == 0:
            warnings.warn("No CDS found. Returning truncated sequence.")
            return sequence[:len(sequence) - len(sequence) % 3]

        first = int(starts[0])
        # The last mark is the FIRST base of the final codon, so the region
        # extends two more bases (exclusive end = last mark + 3).
        end = int(starts[-1]) + 3
        proposed = sequence[first:end]

        remainder = len(proposed) % 3
        if remainder:
            warnings.warn("Irregular CDS. Returning truncated sequence.")
            return proposed[:-remainder]
        return proposed

    def batch_encode_with_cds(self, sequences, cds_tracks, max_length=None, **kwargs):
        """Encode a batch of raw mRNA sequences using CDS-aware preprocessing.

        Each sequence is upper-cased with T->U, reduced to its CDS via
        ``_extract_cds``, and cut into chunks of at most ``max_length`` codons.
        All chunks of all sequences are encoded together in one call.

        Args:
            sequences: List of raw nucleotide strings.
            cds_tracks: List of numpy arrays (one per sequence). Non-zero values
                mark the first nucleotide of each codon in the CDS region.
            max_length: Max content codon-tokens per chunk (special tokens NOT
                counted). Defaults to model_max_length - 2. This matches the
                convention in compare_minimal_vs_mm.py where max_length is
                already adjusted for special tokens.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors, padding).

        Returns:
            (BatchEncoding, chunk_counts): chunk_counts[i] is the number of
            chunks produced from sequence i. A sequence with no usable CDS
            contributes a single empty-string chunk, so every count is >= 1.

        Raises:
            ValueError: If ``sequences`` and ``cds_tracks`` differ in length.
        """
        sequences = list(sequences)
        cds_tracks = list(cds_tracks)
        if len(sequences) != len(cds_tracks):
            # zip() would silently drop the unmatched tail and leave
            # chunk_counts misaligned with the caller's input.
            raise ValueError(
                f"sequences and cds_tracks must have the same length "
                f"(got {len(sequences)} and {len(cds_tracks)})"
            )

        # A falsy max_length (None or 0) falls back to the model default.
        budget_codons = max_length or (self.model_max_length - 2)
        budget_nt = budget_codons * 3

        all_strings = []
        chunk_counts = []
        for seq, cds in zip(sequences, cds_tracks):
            # No whitespace stripping here: the CDS track indexes raw positions.
            seq = seq.upper().replace("T", "U")
            cds_seq = self._extract_cds(seq, np.asarray(cds))
            # Keep whole codons only; budget_nt is a multiple of three, so
            # every chunk below then also holds whole codons.
            cds_seq = cds_seq[:len(cds_seq) - len(cds_seq) % 3]
            chunks = [
                cds_seq[i:i + budget_nt]
                for i in range(0, len(cds_seq), budget_nt)
            ]
            # Empty CDS: emit one empty string so the sequence keeps its slot.
            all_strings.extend(chunks or [""])
            chunk_counts.append(len(chunks) or 1)

        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts