Adapt tokenization_interns1.py to transformers>=5.0.0

#13, opened by Zhangyc02
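A quick way to exercise this change is to load the slow tokenizer via remote code under both transformers 4.x and 5.x and tokenize a SMILES string. A minimal smoke test (a sketch, not part of the PR; the model id is an assumption):

# Sketch: load the Python (slow) tokenizer defined by this file under either
# transformers version. The model id is illustrative, not taken from the PR.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained(
    "internlm/Intern-S1",   # assumption: the repo this file lives in
    trust_remote_code=True,
    use_fast=False,         # force the slow tokenizer from tokenization_interns1.py
)
print(tok.tokenize("CC(=O)Oc1ccccc1C(=O)O"))  # aspirin, as SMILES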
Files changed (1)
  1. tokenization_interns1.py +147 -9
tokenization_interns1.py CHANGED
@@ -25,24 +25,27 @@ import regex as re
 import sentencepiece as spm
 from collections import OrderedDict
 
-from transformers.tokenization_utils import PreTrainedTokenizer
 from transformers.tokenization_utils_base import AddedToken, TextInput
-from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
 from transformers.utils import logging
+import transformers
+from packaging import version
+if version.parse(transformers.__version__) >= version.parse("5.0.0"):
+    from transformers.tokenization_python import PreTrainedTokenizer
+else:
+    from transformers.tokenization_utils import PreTrainedTokenizer
 
 
 logger = logging.get_logger(__name__)
 
 try:
-    from rdkit import Chem
-    from rdkit import RDLogger
+    from rdkit import Chem, RDLogger
 
     RDLogger.DisableLog("rdApp.error")
     RDLogger.DisableLog("rdApp.*")
     RDKIT_AVAILABLE = True
 except ImportError:
     logger.warning_once(
-        f"If tokenization with SMILES formula is of necessity, please 'pip install RDKit' for better tokenization quality."
+        "If tokenization with SMILES formula is of necessity, please 'pip install RDKit' for better tokenization quality."
     )
     RDKIT_AVAILABLE = False
 
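The guarded import above only sets a flag; RDKit stays optional. A sketch (not from this diff) of how RDKIT_AVAILABLE is typically consumed downstream, e.g. to canonicalize SMILES before tokenization; canonicalize_smiles is a hypothetical helper:

def canonicalize_smiles(smiles: str) -> str:
    # Fall back to the raw string when RDKit is absent or parsing fails.
    if not RDKIT_AVAILABLE:
        return smiles
    mol = Chem.MolFromSmiles(smiles)
    return Chem.MolToSmiles(mol) if mol is not None else smiles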
@@ -343,7 +346,48 @@ class SmilesCheckModule(InternS1CheckModuleMixin):
         return self.check_brackets(text)
 
 
-class InternS1Tokenizer(Qwen2Tokenizer):
+@lru_cache
+# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
+def bytes_to_unicode():
+    """
+    Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
+    characters the bpe code barfs on.
+
+    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
+    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
+    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
+    tables between utf-8 bytes and unicode strings.
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2**8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2**8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+# Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
+def get_pairs(word):
+    """
+    Return set of symbol pairs in a word.
+
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class InternS1Tokenizer(PreTrainedTokenizer):
     """
     Construct an InternS1 tokenizer. Based on byte-level Byte-Pair-Encoding.
 
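The two GPT-2 helpers inlined above used to come in via the Qwen2Tokenizer base class; now they live in this file (this assumes functools.lru_cache is already imported at the top of the file, as in the Qwen2 original). A quick spot-check of their behavior:

b2u = bytes_to_unicode()
assert b2u[ord("A")] == "A"       # printable ASCII maps to itself
assert b2u[ord(" ")] == "\u0120"  # space maps to "Ġ", the usual byte-level BPE marker
assert get_pairs(("l", "o", "w")) == {("l", "o"), ("o", "w")}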
@@ -408,6 +452,54 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         split_special_tokens=False,
         **kwargs,
     ):
+        bos_token = (
+            AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(bos_token, str)
+            else bos_token
+        )
+        eos_token = (
+            AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(eos_token, str)
+            else eos_token
+        )
+        unk_token = (
+            AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(unk_token, str)
+            else unk_token
+        )
+        pad_token = (
+            AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
+            if isinstance(pad_token, str)
+            else pad_token
+        )
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        bpe_merges = []
+        with open(merges_file, encoding="utf-8") as merges_handle:
+            for i, line in enumerate(merges_handle):
+                line = line.strip()
+                if (i == 0 and line.startswith("#version:")) or not line:
+                    continue
+                bpe_merges.append(tuple(line.split()))
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        # NOTE: the cache can grow without bound and will get really large for long-running processes
+        # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
+        # not a memory leak but appears as one.
+        # GPT2Tokenizer has the same problem, so let's be consistent.
+        self.cache = {}
+
+        self.pat = re.compile(PRETOKENIZE_REGEX)
+
+        if kwargs.get("add_prefix_space", False):
+            logger.warning_once(
+                f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
+            )
+
         self.extra_tokenizer_start_mapping = {}
         self.extra_tokenizer_end_mapping = {}
         self._extra_special_tokens = []
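The vocab/merges loading inlined above likewise mirrors GPT2Tokenizer.__init__. The merges loop skips the #version header and blank lines and ranks merge pairs by file order; a standalone illustration on made-up file contents:

lines = ["#version: 0.2", "Ġ t", "h e", ""]  # made-up merges.txt contents
bpe_merges = []
for i, line in enumerate(lines):
    line = line.strip()
    if (i == 0 and line.startswith("#version:")) or not line:
        continue
    bpe_merges.append(tuple(line.split()))
bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
assert bpe_ranks == {("Ġ", "t"): 0, ("h", "e"): 1}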
@@ -460,6 +552,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             pad_token=pad_token,
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
             split_special_tokens=split_special_tokens,
+            special_tokens_pattern="none",
             **kwargs,
         )
 
@@ -497,6 +590,10 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         """Overload method"""
         return self.vocab_size
 
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     @property
     def logical_auto_tokens(self):
         """Tokens that won't be decoded and only for switching tokenizer"""
@@ -633,9 +730,6 @@ class InternS1Tokenizer(Qwen2Tokenizer):
 
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
 
-        if kwargs:
-            logger.warning(f"Keyword arguments {kwargs} not recognized.")
-
         if hasattr(self, "do_lower_case") and self.do_lower_case:
             # convert non-special tokens to lowercase. Might be super slow as well?
             escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
@@ -785,6 +879,7 @@ class InternS1Tokenizer(Qwen2Tokenizer):
             self._added_tokens_encoder[token.content] = token_index
             if self.verbose:
                 logger.info(f"Adding {token} to the vocabulary")
+
         self._update_trie()
         self._update_total_vocab_size()
 
@@ -814,6 +909,49 @@ class InternS1Tokenizer(Qwen2Tokenizer):
         else:
             return self._bpe_tokenize(text)
 
+    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+                else:
+                    new_word.extend(word[i:j])
+                    i = j
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
     def _bpe_tokenize(self, text, **kwargs):
         text = text.replace(
             "▁", " "
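To follow the copied bpe() loop, a worked trace on made-up ranks (the real ranks come from merges.txt):

bpe_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}  # made-up ranks
word = tuple("lower")                     # ('l', 'o', 'w', 'e', 'r')
# pass 1: lowest-ranked pair ('l', 'o')  -> ('lo', 'w', 'e', 'r')
# pass 2: lowest-ranked pair ('lo', 'w') -> ('low', 'e', 'r')
# pass 3: lowest-ranked pair ('e', 'r')  -> ('low', 'er')
# pass 4: ('low', 'er') has no rank      -> stop
# bpe("lower") returns "low er" (space-joined sub-tokens) and caches it.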