import warnings

import numpy as np
from transformers import BertTokenizer


class CodonBertTokenizer(BertTokenizer):
    """BertTokenizer that auto-converts nucleotide sequences to codon-level tokens.

    Raw nucleotide input is normalized (T->U, uppercase, whitespace stripped),
    then split into non-overlapping 3-mer codons before vocab lookup. Trailing
    1-2 nucleotides that do not form a complete codon are dropped.

    eos_token defaults to "[SEP]" (the same string as the BERT separator) so
    that pooling code that excludes both CLS and EOS/SEP positions works
    correctly.

    Standard usage (raw nucleotides):
        tokenizer("AUGAAAGGG")
        tokenizer(["AUGAAAGGG", "AUGUUUCCC"], return_tensors="pt", padding=True)

    CDS-aware usage (full mRNA + CDS track -> extract CDS, chunk, encode):
        tokenizer.batch_encode_with_cds(
            ["NNNATGAAAGGGNN"],
            cds_tracks=[np.array([0,0,0,1,0,0,1,0,0,1,0,0,0,0])],
            return_tensors="pt",
            padding=True,
        )

    Works with compare_minimal_vs_mm.py --use_cds out of the box.
    """

    def __init__(self, *args, **kwargs):
        """Build the tokenizer, defaulting ``eos_token`` to ``"[SEP]"``.

        All arguments are forwarded unchanged to ``BertTokenizer``; an
        explicitly supplied ``eos_token`` is respected.
        """
        kwargs.setdefault("eos_token", "[SEP]")
        super().__init__(*args, **kwargs)

    def _tokenize(self, text, split_special_tokens=False):
        """Split a nucleotide string into non-overlapping codon tokens.

        Args:
            text: Raw nucleotide string; any whitespace is removed, letters
                are uppercased and ``T`` is mapped to ``U``.
            split_special_tokens: Accepted for signature compatibility with
                the parent class; not used here.

        Returns:
            List of 3-character codon strings. A trailing partial codon
            (1-2 nucleotides) is dropped.
        """
        seq = "".join(text.split()).upper().replace("T", "U")
        # Largest prefix length that is a whole number of codons.
        n = len(seq) - len(seq) % 3
        return [seq[i:i + 3] for i in range(0, n, 3)]

    @staticmethod
    def _extract_cds(sequence, cds):
        """Return the coding region of ``sequence`` described by ``cds``.

        Args:
            sequence: Nucleotide string.
            cds: Array-like track aligned with ``sequence``; non-zero entries
                mark the first nucleotide of each codon in the CDS.

        Returns:
            The substring from the first marked position through the end of
            the codon starting at the last marked position, truncated to a
            multiple of 3. If the track has no marked positions, the whole
            sequence truncated to a multiple of 3 is returned instead. A
            warning is emitted whenever truncation semantics apply (no CDS,
            or a CDS span that is not a multiple of 3).
        """
        # Use one "non-zero" criterion both to detect an empty track and to
        # locate the CDS bounds, so the two decisions can never disagree.
        marked = np.flatnonzero(np.asarray(cds))
        if marked.size == 0:
            warnings.warn("No CDS found. Returning truncated sequence.")
            n = len(sequence) - len(sequence) % 3
            return sequence[:n]

        first = int(marked[0])
        # The last marker is a codon *start*; include its two following bases.
        stop = int(marked[-1]) + 3
        proposed = sequence[first:stop]

        remainder = len(proposed) % 3
        if remainder:
            warnings.warn("Irregular CDS. Returning truncated sequence.")
            return proposed[:-remainder]
        return proposed

    def batch_encode_with_cds(self, sequences, cds_tracks, max_length=None, **kwargs):
        """Encode a batch of raw mRNA sequences using CDS-aware preprocessing.

        Each sequence is normalized (uppercase, T->U), reduced to its CDS via
        ``_extract_cds``, and split into consecutive chunks of at most
        ``max_length`` codons. All chunks from all sequences are then encoded
        in a single ``batch_encode_plus`` call.

        Args:
            sequences: List of raw nucleotide strings.
            cds_tracks: List of numpy arrays (one per sequence). Non-zero
                values mark the first nucleotide of each codon in the CDS
                region.
            max_length: Max content codon-tokens per chunk (special tokens NOT
                counted). Defaults to model_max_length - 2. This matches the
                convention in compare_minimal_vs_mm.py where max_length is
                already adjusted for special tokens.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors,
                padding).

        Returns:
            (BatchEncoding, chunk_counts): chunk_counts[i] is the number of
            chunks produced from sequence i. A sequence with an empty CDS
            contributes a single empty-string chunk so that every input is
            represented in the batch.

        Raises:
            ValueError: If the resulting per-chunk codon budget is not
                positive.
        """
        budget_codons = max_length or (self.model_max_length - 2)
        if budget_codons <= 0:
            # A non-positive step would either crash range() or silently
            # discard every sequence; fail loudly instead.
            raise ValueError(
                f"max_length must yield a positive codon budget, got {budget_codons}"
            )
        budget_nt = budget_codons * 3

        all_strings = []
        chunk_counts = []
        for seq, cds in zip(sequences, cds_tracks):
            # Same normalization as _tokenize, minus whitespace stripping:
            # positions must stay aligned with the CDS track.
            seq = seq.upper().replace("T", "U")
            cds_seq = self._extract_cds(seq, np.asarray(cds))

            # _extract_cds returns whole codons and budget_nt is a multiple
            # of 3, so every slice below is itself a whole number of codons.
            chunks = [
                cds_seq[start:start + budget_nt]
                for start in range(0, len(cds_seq), budget_nt)
            ]

            all_strings.extend(chunks or [""])
            chunk_counts.append(len(chunks) or 1)

        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts