import warnings
import numpy as np
from transformers import BertTokenizer
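

# Framework-free sketch of the preprocessing done by CodonBertTokenizer.
# These are hypothetical helpers that the class does not use. They mirror the
# codon splitting in _tokenize and the CDS slicing in _extract_cds, written
# without transformers or numpy so the logic can be checked on its own.
# Assumes a CDS track is a sequence of numbers whose non-zero entries mark
# codon start positions.
def _sketch_split_codons(text):
    # Strip whitespace, uppercase, and map DNA thymine (T) to RNA uracil (U).
    seq = "".join(text.split()).upper().replace("T", "U")
    # Drop the trailing 1-2 nucleotides that cannot form a complete codon.
    n = len(seq) - len(seq) % 3
    return [seq[i:i + 3] for i in range(0, n, 3)]


def _sketch_slice_cds(sequence, cds):
    # Positions whose track value is non-zero are codon starts.
    starts = [i for i, v in enumerate(cds) if v]
    if not starts:
        # No CDS annotated: fall back to the whole sequence, trimmed to whole codons.
        return sequence[:len(sequence) - len(sequence) % 3]
    # Slice from the first codon start through the last nucleotide of the last
    # codon, then trim in case the sequence ends mid-codon.
    region = sequence[starts[0]:starts[-1] + 3]
    return region[:len(region) - len(region) % 3]


# Example: _sketch_split_codons("aug aaa GGG t") gives ["AUG", "AAA", "GGG"].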


class CodonBertTokenizer(BertTokenizer):
    """BertTokenizer that auto-converts nucleotide sequences to codon-level tokens.

    Raw nucleotide input is normalized (T->U, uppercase, whitespace stripped),
    then split into non-overlapping 3-mer codons before vocab lookup. Trailing
    1-2 nucleotides that do not form a complete codon are dropped.

    eos_token defaults to the sep_token ("[SEP]") unless passed explicitly, so
    pooling code that excludes both CLS and EOS/SEP positions works correctly.

    Standard usage (raw nucleotides):
        tokenizer("AUGAAAGGG")
        tokenizer(["AUGAAAGGG", "AUGUUUCCC"], return_tensors="pt", padding=True)

    CDS-aware usage (full mRNA + CDS track -> extract CDS, chunk, encode):
        tokenizer.batch_encode_with_cds(
            ["NNNATGAAAGGGNN"],
            cds=[np.array([0,0,0,1,0,0,1,0,0,1,0,0,0,0])],
            return_tensors="pt",
            padding=True,
        )

    Works with compare_minimal_vs_mm.py --use_cds out of the box.
    """

    def __init__(self, *args, **kwargs):
        kwargs.setdefault("eos_token", "[SEP]")
        super().__init__(*args, **kwargs)

    def _tokenize(self, text, split_special_tokens=False):
        seq = "".join(text.split()).upper().replace("T", "U")
        n = len(seq) - len(seq) % 3
        return [seq[i:i + 3] for i in range(0, n, 3)]

    @staticmethod
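    def _extract_cds(sequence, cds):
        """Return the CDS slice of `sequence`, trimmed to whole codons.

        Non-zero values in `cds` mark codon start positions. The slice runs
        from the first marked position through the last nucleotide of the
        last marked codon.
        """
        starts = np.flatnonzero(np.asarray(cds))
        if starts.size == 0:
            warnings.warn("No CDS found. Returning truncated sequence.")
            n = len(sequence) - len(sequence) % 3
            return sequence[:n]
        first = int(starts[0])
        last = int(starts[-1]) + 2  # last nucleotide of the final codon
        proposed = sequence[first:last + 1]
        if len(proposed) % 3 != 0:
            warnings.warn("Irregular CDS. Returning truncated sequence.")
            return proposed[:-(len(proposed) % 3)]
        return proposed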

    def batch_encode_with_cds(self, sequences, cds_tracks, max_length=None, **kwargs):
        """Encode a batch of raw mRNA sequences using CDS-aware preprocessing.

        Args:
            sequences: List of raw nucleotide strings.
            cds_tracks: List of numpy arrays (one per sequence). Non-zero values
                mark the first nucleotide of each codon in the CDS region.
            max_length: Max content codon-tokens per chunk (special tokens NOT
                counted). Defaults to model_max_length - 2. This matches the
                convention in compare_minimal_vs_mm.py where max_length is
                already adjusted for special tokens.
            **kwargs: Forwarded to batch_encode_plus (e.g. return_tensors, padding).

        Returns:
            (BatchEncoding, chunk_counts): chunk_counts[i] is the number of
            chunks produced from sequence i.
        """
        budget_codons = max_length or (self.model_max_length - 2)
        budget_nt = budget_codons * 3

        all_strings = []
        chunk_counts = []

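        for seq, cds in zip(sequences, cds_tracks):
            # Normalize to uppercase RNA (T->U) before slicing by the CDS track.
            seq = seq.upper().replace("T", "U")
            cds_seq = self._extract_cds(seq, np.asarray(cds))
            n = len(cds_seq)
            # Split the CDS into consecutive chunks of at most budget_nt
            # nucleotides; budget_nt is a multiple of 3, so chunks stay codon-aligned.
            chunks = []
            for i in range(0, max(n, 1), budget_nt):
                chunk = cds_seq[i:i + budget_nt]
                chunk = chunk[:len(chunk) - len(chunk) % 3]
                if chunk:
                    chunks.append(chunk)
            # An empty CDS still yields one (empty) entry, so every input
            # sequence maps to at least one row of the batch.
            all_strings.extend(chunks or [""])
            chunk_counts.append(len(chunks) or 1)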

        enc = self.batch_encode_plus(all_strings, **kwargs)
        return enc, chunk_counts
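

# Hypothetical usage sketch (not part of the tokenizer API): because
# batch_encode_with_cds flattens every input sequence into one or more chunks,
# downstream code has to fold per-chunk outputs back into one row per
# sequence using chunk_counts. This assumes one pooled vector per chunk, in
# encoding order, and averages each sequence's chunks. Averaging is an
# illustrative choice, not something the tokenizer prescribes.
def merge_chunk_embeddings(chunk_vectors, chunk_counts):
    if sum(chunk_counts) != len(chunk_vectors):
        raise ValueError("chunk_counts must sum to the number of chunk vectors")
    merged = []
    start = 0
    for count in chunk_counts:
        rows = chunk_vectors[start:start + count]
        # Element-wise mean over this sequence's chunk vectors.
        merged.append([sum(col) / count for col in zip(*rows)])
        start += count
    return merged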