'''
This code is adapted from:
https://github.com/AlibabaResearch/AdvancedLiterateMachinery/blob/main/OCR/MGP-STR
'''
import numpy as np

from openrec.preprocess.ctc_label_encode import BaseRecLabelEncode


class MGPLabelEncode(BaseRecLabelEncode):
    """Convert between text-label and text-index."""

    # Note: despite its name, SPACE ('[s]') is the end-of-sequence token,
    # while GO ('[GO]') marks the start of the sequence (and is reused as padding).
    SPACE = '[s]'
    GO = '[GO]'
    list_token = [GO, SPACE]

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 use_space_char=False,
                 only_char=False,
                 **kwargs):
        super(MGPLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             use_space_char)
        # character (str): set of the possible characters.
        # [GO] is the start token of the attention decoder; [s] is the
        # end-of-sentence token.
        self.batch_max_length = max_text_length + len(self.list_token)
        self.only_char = only_char
        if not only_char:
            # transformers==4.2.1
            from transformers import BertTokenizer, GPT2Tokenizer
            self.bpe_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.wp_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')
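
    # Illustrative example (not in the original source): add_special_char
    # below prepends the special tokens, so self.dict maps '[GO]' -> 0 and
    # '[s]' -> 1. For the text 'ab' with max_text_length=5
    # (batch_max_length = 7), the character label is laid out as
    #   [GO]  a  b  [s]  pad  pad  pad
    #   [0, dict['a'], dict['b'], 1, 0, 0, 0]
    # where padding reuses the [GO] index.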

    def __call__(self, data):
        text = data['label']
        # Character-level label: [GO] + char ids + [s], padded to batch_max_length.
        char_text, char_len = self.encode(text)
        if char_text is None:
            return None
        data['length'] = np.array(char_len)
        data['char_label'] = np.array(char_text)
        if self.only_char:
            return data
        # Subword labels from the GPT-2 BPE and BERT WordPiece tokenizers.
        bpe_text = self.bpe_encode(text)
        if bpe_text is None:
            return None
        wp_text = self.wp_encode(text)
        data['bpe_label'] = np.array(bpe_text)
        data['wp_label'] = wp_text
        return data

    def add_special_char(self, dict_character):
        # Prepend [GO] and [s] so they take indices 0 and 1.
        dict_character = self.list_token + dict_character
        return dict_character

    def encode(self, text):
        """Convert a text label into a padded list of character indices."""
        if len(text) == 0:
            return None, None
        if self.lower:
            text = text.lower()
        # Note: length is the raw character count, taken before any
        # out-of-dictionary characters are dropped below.
        length = len(text)
        text = [self.GO] + list(text) + [self.SPACE]
        text_list = []
        for char in text:
            if char not in self.dict:
                # Skip characters that are not in the dictionary.
                continue
            text_list.append(self.dict[char])
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None, None
        # Pad to batch_max_length with the [GO] index.
        text_list = text_list + [self.dict[self.GO]] * (
            self.batch_max_length - len(text_list))
        return text_list, length

    def bpe_encode(self, text):
        if len(text) == 0:
            return None
        token = self.bpe_tokenizer(text)['input_ids']
        # Wrap the BPE ids with fixed start (1) and end (2) token ids.
        text_list = [1] + token + [2]
        if len(text_list) == 0 or len(text_list) > self.batch_max_length:
            return None
        # Pad to batch_max_length with the [GO] index.
        text_list = text_list + [self.dict[self.GO]] * (
            self.batch_max_length - len(text_list))
        return text_list

    def wp_encode(self, text):
        # WordPiece label: the BERT tokenizer adds [CLS]/[SEP] and handles
        # padding and truncation to batch_max_length internally.
        wp_target = self.wp_tokenizer([text],
                                      padding='max_length',
                                      max_length=self.batch_max_length,
                                      truncation=True,
                                      return_tensors='np')
        return wp_target['input_ids'][0]
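

# --- Minimal usage sketch (illustrative; not part of the original file). ---
# Assumes the default lowercase-alphanumeric character set that
# BaseRecLabelEncode falls back to when character_dict_path is None; with
# only_char=False, the 'gpt2' and 'bert-base-uncased' checkpoints would also
# need to be available via `transformers`.
if __name__ == '__main__':
    encoder = MGPLabelEncode(max_text_length=25, only_char=True)
    sample = encoder({'label': 'hello'})
    print(sample['length'])            # 5
    print(sample['char_label'].shape)  # (27,) = max_text_length + 2 tokens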