| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import json |
| import os |
| import re |
| from typing import Dict, List, Optional, Tuple, Any |
|
|
| from transformers import PreTrainedTokenizer |
|
|
|
|
class BinaryLLMTokenizer(PreTrainedTokenizer):
    """Byte-level tokenizer that maps UTF-8 text to base-65536 "digits".

    Encoding scheme:
      * Text is UTF-8 encoded, then consecutive byte pairs are packed
        big-endian into one token id each: ``(hi << 8) | lo`` in [0, 65535].
      * If the byte count is odd, the final byte becomes its own token id
        in [0, 255] (the ``last_byte_as_single_digit_0_255`` rule recorded
        by :meth:`save_vocabulary`).
      * The empty string encodes to ``[0]`` so no input yields an empty
        sequence; note it therefore decodes back to ``"\\x00"``, not ``""``
        — a known, deliberate asymmetry of the scheme.
      * Ids 65536 and 65537 are reserved for BOS and EOS. UNK deliberately
        aliases the EOS id, keeping the total vocabulary at 65538.

    NOTE(review): decoding treats any id <= 255 as a single byte, which is
    unambiguous only for NUL-free text (a packed pair is <= 255 exactly
    when its high byte is 0x00).
    """

    model_input_names = ["input_ids", "attention_mask"]

    # Canonical surface form of a base "digit" token, e.g. "<U00FF>".
    # Lower-case hex is accepted on input; output is always upper-case.
    TOKEN_RE = re.compile(r"^<U([0-9A-Fa-f]{4})>$")

    def __init__(
        self,
        bos_token: str = "<BOS>",
        eos_token: str = "<EOS>",
        unk_token: str = "<UNK>",
        pad_token: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialize the tokenizer.

        Args:
            bos_token: Surface form of the beginning-of-sequence token.
            eos_token: Surface form of the end-of-sequence token.
            unk_token: Surface form of the unknown token (shares EOS's id).
            pad_token: Optional pad token; it resolves to the EOS id (or the
                BOS id if it equals ``bos_token``) — no dedicated pad id exists.
            **kwargs: Forwarded to ``PreTrainedTokenizer``.
        """
        # 65536 "digit" ids followed by the two reserved special ids.
        # These must be set before super().__init__(), which may call back
        # into get_vocab()/_convert_token_to_id().
        self._base_vocab_size = 65536
        self._bos_id = 65536
        self._eos_id = 65537
        # UNK intentionally aliases EOS so vocab_size stays at 65538.
        self._unk_id = self._eos_id

        self._bos_str = bos_token
        self._eos_str = eos_token
        self._unk_str = unk_token
        self._pad_str = pad_token

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Total id count: 65536 base digits + BOS + EOS (UNK aliases EOS)."""
        return 65538

    def get_vocab(self) -> Dict[str, int]:
        """Return only the special-token entries.

        The 65536 base tokens are synthesized on demand from their ids and
        are deliberately not materialized here.
        """
        vocab = {
            self._bos_str: self._bos_id,
            self._eos_str: self._eos_id,
            self._unk_str: self._unk_id,
        }
        if self.pad_token is not None:
            vocab[self.pad_token] = self._convert_token_to_id(self.pad_token)
        return vocab

    def _id_to_token_base(self, i: int) -> str:
        """Render a base id (0..65535) in its canonical ``<U%04X>`` form."""
        return f"<U{i:04X}>"

    def _encode_to_base65536_big_endian(self, text: str) -> List[int]:
        """UTF-8 encode *text* and pack byte pairs big-endian into ids.

        Returns:
            Ids in [0, 65535]; an odd trailing byte becomes its own id in
            [0, 255]. The empty string maps to ``[0]``.

        Raises:
            UnicodeEncodeError: never for ``str`` input (``errors="strict"``
                is kept for parity with the documented encoding contract).
        """
        data = text.encode("utf-8", errors="strict")
        if not data:
            # Guarantee a non-empty id sequence for empty input.
            return [0]

        ids: List[int] = []
        n = len(data)
        i = 0
        while i + 1 < n:
            # Big-endian pair: first byte is the high half of the digit.
            ids.append((data[i] << 8) | data[i + 1])
            i += 2
        if i < n:
            # Odd byte count: emit the final byte as a single "digit".
            ids.append(data[i])
        return ids

    def _decode_from_base65536_big_endian(self, ids: List[int]) -> str:
        """Inverse of :meth:`_encode_to_base65536_big_endian`.

        Ids <= 255 contribute one byte, larger base ids two bytes.
        Ids outside [0, 65535] (e.g. BOS/EOS specials) are skipped — the
        previous implementation masked them and emitted garbage NUL bytes.
        """
        buf = bytearray()
        for raw in ids:
            value = int(raw)
            if not 0 <= value < self._base_vocab_size:
                # Special or out-of-range ids carry no byte payload.
                continue
            if value <= 255:
                buf.append(value)
            else:
                buf.append(value >> 8)
                buf.append(value & 0xFF)
        # errors="replace": tolerate byte sequences that are not valid UTF-8.
        return bytes(buf).decode("utf-8", errors="replace")

    def _tokenize(self, text: str) -> List[str]:
        """Split *text* into the canonical ``<U%04X>`` surface tokens."""
        return [
            self._id_to_token_base(i)
            for i in self._encode_to_base65536_big_endian(text)
        ]

    def _convert_token_to_id(self, token: str) -> int:
        """Map a surface token to its id; unknown strings get the UNK id."""
        if token == self._bos_str:
            return self._bos_id
        if token == self._eos_str:
            return self._eos_id
        if token == self._unk_str:
            return self._unk_id

        # A pad token distinct from the named specials still maps onto the
        # EOS id: the model reserves no dedicated pad id.
        if self.pad_token is not None and token == self.pad_token:
            return self._eos_id

        m = self.TOKEN_RE.match(token)
        if m:
            return int(m.group(1), 16)

        return self._unk_id

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its surface token (UNK string for junk ids)."""
        if index == self._bos_id:
            return self._bos_str
        if index == self._eos_id:
            # Also covers the UNK id, which aliases EOS.
            return self._eos_str

        # No pad check needed: pad_token_id always resolves to the BOS or
        # EOS id (see _convert_token_to_id), both handled above.
        if 0 <= index < self._base_vocab_size:
            return self._id_to_token_base(index)

        return self._unk_str

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Decode surface tokens back to text, dropping special tokens."""
        ids: List[int] = []
        for t in tokens:
            if t in (self._bos_str, self._eos_str, self._unk_str):
                continue
            if self.pad_token is not None and t == self.pad_token:
                continue
            m = self.TOKEN_RE.match(t)
            if m:
                ids.append(int(m.group(1), 16))
        return self._decode_from_base65536_big_endian(ids)

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Wrap sequences as ``BOS seq0 EOS`` or ``BOS seq0 EOS seq1 EOS``."""
        if token_ids_1 is None:
            return [self._bos_id] + token_ids_0 + [self._eos_id]
        return [self._bos_id] + token_ids_0 + [self._eos_id] + token_ids_1 + [self._eos_id]

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return 1 for special-token positions, 0 for content positions.

        NOTE(review): when ``already_has_special_tokens`` is True,
        ``token_ids_1`` is ignored (kept for backward compatibility; the
        stock HF implementation would reject a second sequence here).
        """
        pad_id = self.pad_token_id if self.pad_token is not None else -1

        if already_has_special_tokens:
            return [
                1 if t in (self._bos_id, self._eos_id, self._unk_id, pad_id) else 0
                for t in token_ids_0
            ]

        # Positions mirror build_inputs_with_special_tokens.
        if token_ids_1 is None:
            return [1] + [0] * len(token_ids_0) + [1]
        return [1] + [0] * len(token_ids_0) + [1] + [0] * len(token_ids_1) + [1]

    def create_token_type_ids_from_sequences(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """All-zero segment ids (+2 specials single, +3 specials paired)."""
        if token_ids_1 is None:
            return [0] * (len(token_ids_0) + 2)
        return [0] * (len(token_ids_0) + len(token_ids_1) + 3)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Write the encoding metadata (there is no materialized vocab).

        Args:
            save_directory: Target directory, created if missing.
            filename_prefix: Optional prefix for the written file name.

        Returns:
            One-tuple with the path of the written JSON file.
        """
        # exist_ok makes the isdir pre-check unnecessary.
        os.makedirs(save_directory, exist_ok=True)

        name = (filename_prefix + "-" if filename_prefix else "") + "binaryllm_vocab.json"
        path = os.path.join(save_directory, name)

        # Self-describing record of the fixed encoding scheme; loaders can
        # validate against it instead of hard-coding the constants.
        data = {
            "base_vocab_size": 65536,
            "vocab_size": 65538,
            "bos_token": self._bos_str,
            "bos_token_id": self._bos_id,
            "eos_token": self._eos_str,
            "eos_token_id": self._eos_id,
            "unk_token": self._unk_str,
            "unk_token_id": self._unk_id,
            "pad_token": self.pad_token,
            "pad_token_id": self.pad_token_id,
            "encoding": "utf-8",
            "radix": 65536,
            "endianness": "big",
            "odd_length_rule": "last_byte_as_single_digit_0_255",
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return (path,)
|
|