# Copyright (c) Antoine Nzeyimana. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from __future__ import print_function, division import math import random import sys # Ignore warnings import warnings from typing import List import numpy as np import progressbar import torch import youtokentome as yttm from torch.utils.data import Dataset from kinyabert_utils import time_now warnings.filterwarnings("ignore") def read_corpus(fn): f = open(fn, 'r+') corpus_lines = [line.rstrip('\n') for line in f] f.close() corpus_lines.append("\n") return corpus_lines def read_vocab_idx(fn, voc, voc_idx): f = open(fn, 'r') v_lines = f.readlines() f.close() for v in v_lines: if(len(v)>1): vt = v.split('\t') if(len(vt) == 2): voc_idx[int(vt[1])] = vt[0] voc[vt[0]] = int(vt[1]) return voc, voc_idx def read_vocab_counts(fn, voc, voc_idx_counts): f = open(fn, 'r') v_lines = f.readlines() f.close() for v in v_lines: if(len(v)>1): vt = v.split('\t') if(len(vt) == 2): id = voc[vt[0]] voc_idx_counts[id] = int(vt[1]) return voc, voc_idx_counts def reduce_stem_vocab(_stem_vocab, _stem_vocab_idx, _stem_vocab_idx_counts, noun_min_count, verb_min_count, np_min_count, other_morpho_min_count, other_cls_min_count, other_token_min_count): reduced_stem_vocab = dict() mapped_stem_vocab_idx = dict() reduced_stem_vocab_idx_counts = dict() unk_idx = _stem_vocab[''] for i in range(1,len(_stem_vocab_idx)+1): if (i <= unk_idx): mapped_stem_vocab_idx[i] = i reduced_stem_vocab[_stem_vocab_idx[i]] = i reduced_stem_vocab_idx_counts[i] = 1000000 else: key = _stem_vocab_idx[i] if (key.startswith('N:')): min_count = noun_min_count elif (key.startswith('V:')): min_count = verb_min_count elif (key.startswith('QA:') or key.startswith('PO:') or key.startswith('DE:') or key.startswith('NU:') or key.startswith('OT:')): min_count = other_morpho_min_count elif (key.startswith('NP:')): min_count = np_min_count elif (key.startswith('T:')): min_count = other_token_min_count else: min_count = other_cls_min_count if _stem_vocab_idx_counts[i] >= min_count: idx = len(reduced_stem_vocab)+1 mapped_stem_vocab_idx[i] = idx reduced_stem_vocab[_stem_vocab_idx[i]] = idx reduced_stem_vocab_idx_counts[idx] = _stem_vocab_idx_counts[i] else: mapped_stem_vocab_idx[i] = reduced_stem_vocab[''] return reduced_stem_vocab, mapped_stem_vocab_idx, reduced_stem_vocab_idx_counts def sigmoid_score(x, min_val, max_val): s = pow((1.0 + (math.exp(((-8.0) * (x - min_val)) / (max_val - min_val)))), -8.0) return s class KBVocab: def __init__(self, noun_min_count = 200, verb_min_count = 100, np_min_count = 200, other_morpho_min_count = 20, other_cls_min_count = 20, other_token_min_count = 200, pos_tag_vocab_idx_tsv="data/pos_tag_vocab_idx_2021-02-07.tsv", pos_tag_vocab_tsv="data/pos_tag_vocab_2021-02-07.tsv", stem_vocab_idx_tsv="data/stem_vocab_idx_2021-02-07.tsv", stem_vocab_tsv="data/stem_vocab_2021-02-07.tsv", morpheme_slot_vocab_idx_tsv="data/morpheme_slot_vocab_idx_2021-02-07.tsv", morpheme_slot_vocab_tsv="data/morpheme_slot_vocab_2021-02-07.tsv", affix_vocab_idx_tsv="data/affix_vocab_idx_2021-02-07.tsv", affix_vocab_tsv="data/affix_vocab_2021-02-07.tsv", read_vocab_files = False): self.pos_tag_vocab = dict() self.pos_tag_vocab_idx = dict() self.pos_tag_vocab_idx_counts = dict() self.pos_tag_vocab_idx_subsample_weights = dict() self.pos_tag_vocab_idx_subsample_weights_max = 0.1 self.pos_tag_vocab_idx_subsample_weights_min = 0.1 if(read_vocab_files): read_vocab_idx(pos_tag_vocab_idx_tsv, self.pos_tag_vocab, self.pos_tag_vocab_idx) read_vocab_counts(pos_tag_vocab_tsv, self.pos_tag_vocab, self.pos_tag_vocab_idx_counts) self._stem_vocab = dict() self._stem_vocab_idx = dict() self._stem_vocab_idx_counts = dict() if(read_vocab_files): read_vocab_idx(stem_vocab_idx_tsv, self._stem_vocab, self._stem_vocab_idx) read_vocab_counts(stem_vocab_tsv, self._stem_vocab, self._stem_vocab_idx_counts) self.reduced_stem_vocab = dict() self.reduced_stem_vocab_idx = dict() self.mapped_stem_vocab_idx = dict() self.reduced_stem_vocab_idx_counts = dict() self.reduced_stem_vocab_idx_subsample_weights = dict() self.reduced_stem_vocab_idx_subsample_weights_max = 0.1 self.reduced_stem_vocab_idx_subsample_weights_min = 0.1 if(read_vocab_files): (self.reduced_stem_vocab, self.mapped_stem_vocab_idx, self.reduced_stem_vocab_idx_counts) = reduce_stem_vocab(self._stem_vocab, self._stem_vocab_idx, self._stem_vocab_idx_counts, noun_min_count, verb_min_count, np_min_count, other_morpho_min_count, other_cls_min_count, other_token_min_count) for k in self.reduced_stem_vocab: self.reduced_stem_vocab_idx[self.reduced_stem_vocab[k]] = k self.morpheme_slot_vocab = dict() self.morpheme_slot_vocab_idx = dict() self.morpheme_slot_vocab_idx_counts = dict() self.morpheme_slot_vocab_idx_subsample_weights = dict() self.morpheme_slot_vocab_idx_subsample_weights_max = 0.1 self.morpheme_slot_vocab_idx_subsample_weights_min = 0.1 if(read_vocab_files): read_vocab_idx(morpheme_slot_vocab_idx_tsv, self.morpheme_slot_vocab, self.morpheme_slot_vocab_idx) read_vocab_counts(morpheme_slot_vocab_tsv, self.morpheme_slot_vocab, self.morpheme_slot_vocab_idx_counts) self.affix_vocab = dict() self.affix_vocab_idx = dict() self.affix_vocab_idx_counts = dict() self.affix_vocab_idx_subsample_weights = dict() self.affix_vocab_idx_subsample_weights_max = 0.1 self.affix_vocab_idx_subsample_weights_min = 0.1 if(read_vocab_files): read_vocab_idx(affix_vocab_idx_tsv, self.affix_vocab, self.affix_vocab_idx) read_vocab_counts(affix_vocab_tsv, self.affix_vocab, self.affix_vocab_idx_counts) def compute_subsampling_weights(self): eps = 1e-5 sum_pos_tags = sum([self.pos_tag_vocab_idx_counts[k] for k in self.pos_tag_vocab_idx_counts]) for k in self.pos_tag_vocab_idx_counts: z = float(self.pos_tag_vocab_idx_counts[k])/float(sum_pos_tags) z = z + eps self.pos_tag_vocab_idx_subsample_weights[k] = (math.sqrt(z/0.001)+1.0)*(0.001/z) self.pos_tag_vocab_idx_subsample_weights_max = max([self.pos_tag_vocab_idx_subsample_weights[k] for k in self.pos_tag_vocab_idx_subsample_weights]) self.pos_tag_vocab_idx_subsample_weights_min = min([self.pos_tag_vocab_idx_subsample_weights[k] for k in self.pos_tag_vocab_idx_subsample_weights]) for k in self.pos_tag_vocab_idx_subsample_weights: if (k <= self.pos_tag_vocab['']): self.pos_tag_vocab_idx_subsample_weights[k] = self.pos_tag_vocab_idx_subsample_weights_min sum_morpheme_slots = sum([self.morpheme_slot_vocab_idx_counts[k] for k in self.morpheme_slot_vocab_idx_counts]) for k in self.morpheme_slot_vocab_idx_counts: z = float(self.morpheme_slot_vocab_idx_counts[k])/float(sum_morpheme_slots) z = z + eps self.morpheme_slot_vocab_idx_subsample_weights[k] = (math.sqrt(z/0.001)+1.0)*(0.001/z) self.morpheme_slot_vocab_idx_subsample_weights_max = max([self.morpheme_slot_vocab_idx_subsample_weights[k] for k in self.morpheme_slot_vocab_idx_subsample_weights]) self.morpheme_slot_vocab_idx_subsample_weights_min = min([self.morpheme_slot_vocab_idx_subsample_weights[k] for k in self.morpheme_slot_vocab_idx_subsample_weights]) for k in self.morpheme_slot_vocab_idx_subsample_weights: if (k <= self.morpheme_slot_vocab['']): self.morpheme_slot_vocab_idx_subsample_weights[k] = self.morpheme_slot_vocab_idx_subsample_weights_min sum_affixes = sum([self.affix_vocab_idx_counts[k] for k in self.affix_vocab_idx_counts]) for k in self.affix_vocab_idx_counts: z = float(self.affix_vocab_idx_counts[k])/float(sum_affixes) z = z + eps self.affix_vocab_idx_subsample_weights[k] = (math.sqrt(z/0.001)+1.0)*(0.001/z) self.affix_vocab_idx_subsample_weights_max = max([self.affix_vocab_idx_subsample_weights[k] for k in self.affix_vocab_idx_subsample_weights]) self.affix_vocab_idx_subsample_weights_min = min([self.affix_vocab_idx_subsample_weights[k] for k in self.affix_vocab_idx_subsample_weights]) for k in self.affix_vocab_idx_subsample_weights: if (k <= self.affix_vocab['']): self.affix_vocab_idx_subsample_weights[k] = self.affix_vocab_idx_subsample_weights_min sum_stems = sum([self.reduced_stem_vocab_idx_counts[k] for k in self.reduced_stem_vocab_idx_counts]) for k in self.reduced_stem_vocab_idx_counts: z = float(self.reduced_stem_vocab_idx_counts[k])/float(sum_stems) z = z + eps self.reduced_stem_vocab_idx_subsample_weights[k] = (math.sqrt(z/0.001)+1.0)*(0.001/z) self.reduced_stem_vocab_idx_subsample_weights_max = max([self.reduced_stem_vocab_idx_subsample_weights[k] for k in self.reduced_stem_vocab_idx_subsample_weights]) self.reduced_stem_vocab_idx_subsample_weights_min = min([self.reduced_stem_vocab_idx_subsample_weights[k] for k in self.reduced_stem_vocab_idx_subsample_weights]) for k in self.reduced_stem_vocab_idx_subsample_weights: if (k <= self.reduced_stem_vocab['']): self.reduced_stem_vocab_idx_subsample_weights[k] = self.reduced_stem_vocab_idx_subsample_weights_min def state_dict(self): return {'pos_tag_vocab':self.pos_tag_vocab, 'pos_tag_vocab_idx':self.pos_tag_vocab_idx, 'pos_tag_vocab_idx_counts':self.pos_tag_vocab_idx_counts, '_stem_vocab':self._stem_vocab, '_stem_vocab_idx':self._stem_vocab_idx, '_stem_vocab_idx_counts':self._stem_vocab_idx_counts, 'reduced_stem_vocab':self.reduced_stem_vocab, 'mapped_stem_vocab_idx':self.mapped_stem_vocab_idx, 'reduced_stem_vocab_idx_counts':self.reduced_stem_vocab_idx_counts, 'morpheme_slot_vocab':self.morpheme_slot_vocab, 'morpheme_slot_vocab_idx':self.morpheme_slot_vocab_idx, 'morpheme_slot_vocab_idx_counts':self.morpheme_slot_vocab_idx_counts, 'affix_vocab':self.affix_vocab, 'affix_vocab_idx':self.affix_vocab_idx, 'affix_vocab_idx_counts':self.affix_vocab_idx_counts, 'morpheme_slot_vocab_idx_subsample_weights': self.morpheme_slot_vocab_idx_subsample_weights, 'morpheme_slot_vocab_idx_subsample_weights_max': self.morpheme_slot_vocab_idx_subsample_weights_max, 'morpheme_slot_vocab_idx_subsample_weights_min': self.morpheme_slot_vocab_idx_subsample_weights_min, 'affix_vocab_idx_subsample_weights':self.affix_vocab_idx_subsample_weights, 'affix_vocab_idx_subsample_weights_max':self.affix_vocab_idx_subsample_weights_max, 'affix_vocab_idx_subsample_weights_min':self.affix_vocab_idx_subsample_weights_min, 'pos_tag_vocab_idx_subsample_weights': self.pos_tag_vocab_idx_subsample_weights, 'pos_tag_vocab_idx_subsample_weights_max': self.pos_tag_vocab_idx_subsample_weights_max, 'pos_tag_vocab_idx_subsample_weights_min': self.pos_tag_vocab_idx_subsample_weights_min, 'reduced_stem_vocab_idx_subsample_weights':self.reduced_stem_vocab_idx_subsample_weights, 'reduced_stem_vocab_idx_subsample_weights_max':self.reduced_stem_vocab_idx_subsample_weights_max, 'reduced_stem_vocab_idx_subsample_weights_min':self.reduced_stem_vocab_idx_subsample_weights_min} def load_state_dict(self, d): self.pos_tag_vocab = d['pos_tag_vocab'] self.pos_tag_vocab_idx = d['pos_tag_vocab_idx'] self.pos_tag_vocab_idx_counts = d['pos_tag_vocab_idx_counts'] self._stem_vocab = d['_stem_vocab'] self._stem_vocab_idx = d['_stem_vocab_idx'] self._stem_vocab_idx_counts = d['_stem_vocab_idx_counts'] self.reduced_stem_vocab = d['reduced_stem_vocab'] self.mapped_stem_vocab_idx = d['mapped_stem_vocab_idx'] self.reduced_stem_vocab_idx_counts = d['reduced_stem_vocab_idx_counts'] self.morpheme_slot_vocab = d['morpheme_slot_vocab'] self.morpheme_slot_vocab_idx = d['morpheme_slot_vocab_idx'] self.morpheme_slot_vocab_idx_counts = d['morpheme_slot_vocab_idx_counts'] self.affix_vocab = d['affix_vocab'] self.affix_vocab_idx = d['affix_vocab_idx'] self.affix_vocab_idx_counts = d['affix_vocab_idx_counts'] self.morpheme_slot_vocab_idx_subsample_weights = d['morpheme_slot_vocab_idx_subsample_weights'] self.morpheme_slot_vocab_idx_subsample_weights_max = d['morpheme_slot_vocab_idx_subsample_weights_max'] self.morpheme_slot_vocab_idx_subsample_weights_min = d['morpheme_slot_vocab_idx_subsample_weights_min'] self.pos_tag_vocab_idx_subsample_weights = d['pos_tag_vocab_idx_subsample_weights'] self.pos_tag_vocab_idx_subsample_weights_max = d['pos_tag_vocab_idx_subsample_weights_max'] self.pos_tag_vocab_idx_subsample_weights_min = d['pos_tag_vocab_idx_subsample_weights_min'] self.affix_vocab_idx_subsample_weights = d['affix_vocab_idx_subsample_weights'] self.affix_vocab_idx_subsample_weights_max = d['affix_vocab_idx_subsample_weights_max'] self.affix_vocab_idx_subsample_weights_min = d['affix_vocab_idx_subsample_weights_min'] self.reduced_stem_vocab_idx_subsample_weights = d['reduced_stem_vocab_idx_subsample_weights'] self.reduced_stem_vocab_idx_subsample_weights_max = d['reduced_stem_vocab_idx_subsample_weights_max'] self.reduced_stem_vocab_idx_subsample_weights_min = d['reduced_stem_vocab_idx_subsample_weights_min'] for k in self.reduced_stem_vocab: self.reduced_stem_vocab_idx[self.reduced_stem_vocab[k]] = k self.reduced_stem_vocab_idx_counts = dict() for i in self._stem_vocab_idx_counts: self.reduced_stem_vocab_idx_counts[self.mapped_stem_vocab_idx[i]] = self._stem_vocab_idx_counts[i] class AffixSetVocab: def __init__(self, reduced_affix_dict_file = None, reduced_affix_dict_map_file = None): self.affix_set_vocab_idx = dict() self.affix_set_vocab = dict() self.reduced_affix_dict_counts = dict() self.reduced_affix_dict_map = dict() if reduced_affix_dict_file is not None: f = open(reduced_affix_dict_file, 'r') dict_lines = [line.rstrip('\n') for line in f] f.close() idx = 1 for l in dict_lines: if len(l) > 0: spl = l.split(',') if (len(spl) == 2): self.reduced_affix_dict_counts[spl[0]] = int(spl[1]) self.affix_set_vocab_idx[spl[0]] = idx self.affix_set_vocab[idx] = spl[0] idx += 1 if reduced_affix_dict_map_file is not None: f = open(reduced_affix_dict_map_file, 'r') dict_lines = [line.rstrip('\n') for line in f] f.close() for l in dict_lines: if len(l) > 0: spl = l.split(',') if (len(spl) == 2): self.reduced_affix_dict_map[spl[0]] = spl[1] if (spl[1]) in self.affix_set_vocab_idx: self.affix_set_vocab_idx[spl[0]] = self.affix_set_vocab_idx[spl[1]] else: self.affix_set_vocab_idx[spl[0]] = 1 def affix_set_to_idx(self, key): if key in self.affix_set_vocab_idx: return self.affix_set_vocab_idx[key] else: return 1 # N/A def affix_set_idx_to_txt(self, idx, kb_vocab: KBVocab): if idx == 1: return "N/A" elif idx in self.affix_set_vocab: return '-'.join([kb_vocab.affix_vocab_idx[int(x)] for x in self.affix_set_vocab[idx].split('-')]) else: return "UNK" # N/A def random_idx(self): return random.randint(1, len(self.affix_set_vocab_idx)) def state_dict(self): return {'affix_set_vocab_idx':self.affix_set_vocab_idx, 'reduced_affix_dict_counts':self.reduced_affix_dict_counts, 'reduced_affix_dict_map':self.reduced_affix_dict_map} def load_state_dict(self, d): self.affix_set_vocab_idx = d['affix_set_vocab_idx'] self.reduced_affix_dict_counts = d['reduced_affix_dict_counts'] self.reduced_affix_dict_map = d['reduced_affix_dict_map'] class ParsedToken: def __init__(self, surface_form, parsed_token=None, decode_prob=None, tf_idf=0.0, pos_tag_id=None, stem_ids=None, line_num=0): self.surface_form = surface_form self.tf_idf = tf_idf if parsed_token is not None: parts = parsed_token.split('/') self.decode_prob = float(parts[0]) self.tf_idf = float(parts[1]) morphs = parts[2].split(',') pos_stem = morphs[0].split(':') stem_parts = pos_stem[1].split('*') if(len(stem_parts[0]) < 1): # print('\nParsing wrong token: /{}/ at line # {}'.format(parsed_token, line_num)) self.pos_tag_idx = int(pos_stem[0]) self.stem_idx = [6] self.morpho_slots_idx = [] self.affixes_idx = [] else: self.pos_tag_idx = int(pos_stem[0]) self.stem_idx = [int(v) for v in stem_parts] self.morpho_slots_idx = [int(morphs[i].split(':')[0]) for i in range(1, len(morphs))] self.affixes_idx = [int(morphs[i].split(':')[1]) for i in range(1, len(morphs))] else: self.decode_prob = decode_prob self.pos_tag_idx = pos_tag_id self.stem_idx = stem_ids self.morpho_slots_idx = [] self.affixes_idx = [] def append_morpheme(self, morpho_slot_id, affix_id): self.morpho_slots_idx.append(morpho_slot_id) self.affixes_idx.append(affix_id) def to_parsed_format(self): st = ['{}:{}'.format(self.pos_tag_idx,'*'.join([str(i) for i in self.stem_idx]))] for i in range(len(self.morpho_slots_idx)): st.append('{}:{}'.format(self.morpho_slots_idx[i], self.affixes_idx[i])) return '{:.3g}/{:.3g}/{}'.format(self.decode_prob, self.tf_idf, ','.join(st)) def affix_set_key(self): key = '-'.join([str(af) for af in self.affixes_idx]) if (len(self.affixes_idx) > 0) else 'N/A' return key def update_tf_idf_from_idf(doc_sentences): doc_voc = dict() doc_size = 0.0 for parsed_tokens in doc_sentences: for t in parsed_tokens: doc_size += 1.0 v = 0.0 if(t.stem_idx[0] in doc_voc): v = doc_voc[t.stem_idx[0]] doc_voc[t.stem_idx[0]] = v + 1.0 for parsed_tokens in doc_sentences: for t in parsed_tokens: t.tf_idf = sigmoid_score(t.tf_idf * doc_voc[t.stem_idx[0]] / doc_size, 0.01, 0.24) def pre_process_parsed_corpus_compute_tfidf(input_corpus, output_corpus): f = open(input_corpus, 'r') Lines = f.readlines() f.close() outfile = open(output_corpus, 'w') doc_idx = [i for i in range(len(Lines)) if (len(Lines[i]) == 1)] if doc_idx[-1] < (len(Lines) - 1): doc_idx.append(len(Lines)) start_idx = 0 all_docs = len(doc_idx) print_docs = 0 tot = 0 for end_idx in doc_idx: tot = tot + 1 lines_batch = Lines[start_idx:end_idx] start_idx = end_idx + 1 if (len(lines_batch) > 0): doc_sentences = [] for ln in lines_batch: line = ln.strip() line = line.strip('\n') line = line.strip('\t') line = line.strip('\r') splits = line.split('; ') if (len(splits) > 0): if ((len(splits[0]) > 4) and ('/' in splits[0]) and (':' in splits[0])): doc_sentences.append([ParsedToken('_', parsed_token=t) for t in splits]) update_tf_idf_from_idf(doc_sentences) for sent in doc_sentences: outfile.write('; '.join([tok.to_parsed_format() for tok in sent]) + "\n") outfile.write("\n") outfile.flush() print_docs += 1 outfile.close() print('Exported: {} / {} / {}'.format(print_docs, tot, all_docs)) def parse_raw_text_lines(doc_lines, kb_vocab, bpe): from kinlpmorpholib import ffi, lib parsed_tokens = [] num_sent = ffi.new("int[1]") sentences = lib.parse_sentences_batch(doc_lines.encode('utf-8'), num_sent) for i in range(num_sent[0]): sent = sentences[i] for j in range(sent.words_len): w = sent.words[j] POS_TAG = ffi.string(w.pos_tag).decode("utf-8") WORD_TYPE = ffi.string(w.pos_group).decode("utf-8") STEM = ffi.string(w.stem).decode("utf-8") SURFACE_FORM = ffi.string(w.surface_form).decode("utf-8") DECODE_PROB = w.decode_prob TF_IDF = w.tf_idf pti = kb_vocab.pos_tag_vocab[''] if POS_TAG in kb_vocab.pos_tag_vocab.keys(): pti = kb_vocab.pos_tag_vocab[POS_TAG] sids = [] if ((STEM == SURFACE_FORM) and (w.morphemes_len <= 0)): list_sub_words = bpe.encode(SURFACE_FORM, output_type=yttm.OutputType.SUBWORD) for sub_word in list_sub_words: stem_key = WORD_TYPE + ":" + sub_word si = kb_vocab._stem_vocab[''] if stem_key in kb_vocab._stem_vocab.keys(): si = kb_vocab._stem_vocab[stem_key] sids.append(si) else: stem_key = WORD_TYPE + ":" + STEM si = kb_vocab._stem_vocab[''] if stem_key in kb_vocab._stem_vocab.keys(): si = kb_vocab._stem_vocab[stem_key] sids.append(si) ptoken = ParsedToken(SURFACE_FORM, parsed_token=None, decode_prob=DECODE_PROB, tf_idf=TF_IDF, pos_tag_id=pti, stem_ids=sids) if (w.morphemes_len > 0): for k in range(w.morphemes_len): if ((k != w.stem_start_index) and (k != w.stem_end_index)): MORPHEME_SLOT = WORD_TYPE + ":" + str(w.morphemes[k].slot_id) MORPHEME = MORPHEME_SLOT + ":" + ffi.string(w.morphemes[k].morph_token).decode("utf-8") msi = kb_vocab.morpheme_slot_vocab[''] if MORPHEME_SLOT in kb_vocab.morpheme_slot_vocab.keys(): msi = kb_vocab.morpheme_slot_vocab[MORPHEME_SLOT] mi = kb_vocab.affix_vocab[''] if MORPHEME in kb_vocab.affix_vocab.keys(): mi = kb_vocab.affix_vocab[MORPHEME] ptoken.append_morpheme(msi, mi) parsed_tokens.append(ptoken) lib.release_sentence(sentences, num_sent) return parsed_tokens def process_parsed_sentence(args, parsed_tokens_list: List[ParsedToken], add_cls, kv : KBVocab, affix_set_vocab : AffixSetVocab, rel_pos_dict, rel_pos_dmax): # Input data pos_tags = [] stems = [] afsets = [] if args.use_afsets else None affixes = [] tokens_lengths = [] # Predicted data predicted_stems = [] predicted_afsets = [] if args.use_afsets else None predicted_affixes = [] if args.predict_affixes else None predicted_tokens_idx = [] predicted_tokens_affixes_idx = [] if args.predict_affixes else None predicted_tokens_affixes_lengths = [] if args.predict_affixes else None # Add Token if add_cls: pos_tags.append(kv.pos_tag_vocab['']) stems.append(kv.reduced_stem_vocab['']) if args.use_afsets: afsets.append(affix_set_vocab.affix_set_to_idx('')) tokens_lengths.append(0) if (len(parsed_tokens_list) == 0): # New document pos_tags.append(kv.pos_tag_vocab['']) stems.append(kv.reduced_stem_vocab['']) if args.use_afsets: afsets.append(affix_set_vocab.affix_set_to_idx('')) tokens_lengths.append(0) else: for ptoken in parsed_tokens_list: for sidx in ptoken.stem_idx: unchanged = True predict = False rval = random.random() if (rval <= 0.15): # 15% of tokens are predicted predict = True rval /= 0.15 if(rval < 0.8): # 80% of predicted tokens are masked unchanged = False pos_tags.append(kv.pos_tag_vocab['']) stems.append(kv.reduced_stem_vocab['']) if args.use_afsets: afsets.append(affix_set_vocab.affix_set_to_idx('')) vv = rval/0.8 if vv < 0.3: # Include Affixes for 30% of the time to enforce morphology learning affixes.extend([(v) for v in ptoken.affixes_idx]) tokens_lengths.append(len(ptoken.affixes_idx)) else: tokens_lengths.append(0) elif (rval < 0.9): # 10% are replaced by random tokens, 10% are left unchanged unchanged = False rnd_pos = random.randint(kv.pos_tag_vocab[''], len(kv.pos_tag_vocab)-1) rnd_stem = random.randint(kv.reduced_stem_vocab[''], len(kv.reduced_stem_vocab)-1) pos_tags.append(rnd_pos) stems.append(rnd_stem) if args.use_afsets: afsets.append(affix_set_vocab.random_idx()) elif args.inference_model_file is not None: affix_set_vocab.random_idx() vv = rval/0.8 if vv < 0.3: # Include Affixes for 30% of the time to enforce morphology learning affixes.extend([(v) for v in ptoken.affixes_idx]) tokens_lengths.append(len(ptoken.affixes_idx)) else: tokens_lengths.append(0) if(unchanged): pos_tags.append(ptoken.pos_tag_idx) stems.append(kv.mapped_stem_vocab_idx[sidx]) if args.use_afsets: afsets.append(affix_set_vocab.affix_set_to_idx(ptoken.affix_set_key())) affixes.extend([(v) for v in ptoken.affixes_idx]) tokens_lengths.append(len(ptoken.affixes_idx)) if(predict): predicted_stems.append(kv.mapped_stem_vocab_idx[sidx]) predicted_tokens_idx.append(len(tokens_lengths) - 1) if args.use_afsets: predicted_afsets.append(affix_set_vocab.affix_set_to_idx(ptoken.affix_set_key())) if args.predict_affixes: predicted_affixes.extend([(v) for v in ptoken.affixes_idx]) if(len(ptoken.affixes_idx) > 0): predicted_tokens_affixes_idx.append(len(predicted_tokens_idx) - 1) predicted_tokens_affixes_lengths.append(len(ptoken.affixes_idx)) rel_pos_arr = np.zeros((len(pos_tags), len(pos_tags))).astype(int) if (rel_pos_dict is not None) else None if rel_pos_arr is not None: for i,pi in enumerate(pos_tags): for j,pj in enumerate(pos_tags): d = i - j if ((d != 0) and (d >= -rel_pos_dmax) and (d <= rel_pos_dmax) and ((not add_cls) or ((i != 0) and (j != 0)))): rel_pos_arr[i, j] = rel_pos_dict[(pi, pj, d)] return (rel_pos_arr, pos_tags, stems, afsets, affixes, tokens_lengths, predicted_stems, predicted_afsets, predicted_affixes, predicted_tokens_idx, predicted_tokens_affixes_idx, predicted_tokens_affixes_lengths) def gather_replicated_itemized_data(args, corpus_lines, doc_ends, is_corpus_parsed, max_seq_len, start_line, max_batch_items, kb_vocab : KBVocab, affix_set_vocab : AffixSetVocab, bpe: yttm.BPE, rel_pos_dict, rel_pos_dmax, rank=0, bar=None, num_lines=sys.maxsize,shuffle=False,are_parsed_tokens_split=False): itemized_data = [] itemized_parsed_tokens = [] seq_pos_tags = [] seq_stems = [] seq_afsets = [] if args.use_afsets else None seq_affixes = [] seq_tokens_lengths = [] seq_predicted_stems = [] seq_predicted_afsets = [] if args.use_afsets else None seq_predicted_affixes = [] if args.predict_affixes else None seq_predicted_tokens_idx = [] seq_predicted_tokens_affixes_idx = [] if args.predict_affixes else None seq_predicted_tokens_affixes_lengths = [] if args.predict_affixes else None seq_rel_pos_arr = np.zeros((max_seq_len, max_seq_len)).astype(int) if (rel_pos_dict is not None) else None seq_parsed_tokens = [] # if (rank==0): # print(time_now(), 'Gathering itemized input sequence data from line # {}'.format(start_line + 1), flush=True) lcount = 0 while (True): lcount += 1 add_cls = (len(seq_parsed_tokens) == 0) if add_cls: ptoken = ParsedToken('', decode_prob=1.0, tf_idf=0.001, pos_tag_id=kb_vocab.pos_tag_vocab[''], stem_ids=[kb_vocab.reduced_stem_vocab['']]) ptoken.append_morpheme(kb_vocab.morpheme_slot_vocab[''], kb_vocab.affix_vocab['']) seq_parsed_tokens.append(ptoken) line = corpus_lines[start_line % len(corpus_lines)] start_line = (start_line + 1) % len(corpus_lines) parsed_tokens_line = [] if (is_corpus_parsed == True): if are_parsed_tokens_split: parsed_tokens_line = line else: splits = line.split('; ') if (len(splits) > 0): if ((len(splits[0]) > 4) and ('/' in splits[0]) and (':' in splits[0])): parsed_tokens_line = [ParsedToken('_', parsed_token=t, line_num=(start_line + 1)) for t in splits] else: parsed_tokens_line = parse_raw_text_lines(line, kb_vocab, bpe) if (len(parsed_tokens_line) == 0): ptoken = ParsedToken('', decode_prob=1.0, tf_idf=0.001, pos_tag_id=kb_vocab.pos_tag_vocab[''], stem_ids=[kb_vocab.reduced_stem_vocab['']]) ptoken.append_morpheme(kb_vocab.morpheme_slot_vocab[''], kb_vocab.affix_vocab['']) seq_parsed_tokens.append(ptoken) else: seq_parsed_tokens.extend(parsed_tokens_line) (rel_pos_arr, pos_tags, stems, afsets, affixes, tokens_lengths, predicted_stems, predicted_afsets, predicted_affixes, predicted_tokens_idx, predicted_tokens_affixes_idx, predicted_tokens_affixes_lengths) = process_parsed_sentence(args, parsed_tokens_line, add_cls, kb_vocab, affix_set_vocab, rel_pos_dict, rel_pos_dmax) if (len(seq_tokens_lengths) + len(tokens_lengths)) > max_seq_len: data_item = (max_seq_len, seq_rel_pos_arr, seq_pos_tags, seq_stems, seq_afsets, seq_affixes, seq_tokens_lengths, seq_predicted_stems, seq_predicted_afsets, seq_predicted_affixes, seq_predicted_tokens_idx, seq_predicted_tokens_affixes_idx, seq_predicted_tokens_affixes_lengths) itemized_data.append(data_item) itemized_parsed_tokens.append(seq_parsed_tokens) if (shuffle): dcx = random.randint(0, len(doc_ends) - 1) % len(doc_ends) start_line = (doc_ends[dcx]+1) % len(corpus_lines) if random.random() < 0.8: # 80% of the time, start from anywhere within the corpus. start_line = random.randint(0, len(corpus_lines) - 1) % len(corpus_lines) if (len(itemized_data) >= max_batch_items): if ((rank == 0) and (bar is not None)): bar.update(len(itemized_data)) sys.stdout.flush() return itemized_data, itemized_parsed_tokens, start_line if(lcount >= num_lines): return itemized_data, itemized_parsed_tokens, start_line seq_pos_tags = [] seq_stems = [] seq_afsets = [] if args.use_afsets else None seq_affixes = [] seq_tokens_lengths = [] seq_predicted_stems = [] seq_predicted_afsets = [] if args.use_afsets else None seq_predicted_affixes = [] if args.predict_affixes else None seq_predicted_tokens_idx = [] seq_predicted_tokens_affixes_idx = [] if args.predict_affixes else None seq_predicted_tokens_affixes_lengths = [] if args.predict_affixes else None seq_rel_pos_arr = np.zeros((max_seq_len, max_seq_len)).astype(int) if (rel_pos_dict is not None) else None seq_parsed_tokens = [] if ((rank == 0) and ((len(itemized_data) % (math.floor(0.1 * max_batch_items) + 1)) == 0) and (bar is not None)): bar.update(len(itemized_data)) sys.stdout.flush() else: if args.predict_affixes: seq_predicted_tokens_affixes_idx.extend([len(seq_predicted_tokens_idx) + idx for idx in predicted_tokens_affixes_idx]) seq_predicted_tokens_idx.extend([len(seq_tokens_lengths)+idx for idx in predicted_tokens_idx]) lt = len(seq_tokens_lengths) if seq_rel_pos_arr is not None: seq_rel_pos_arr[lt:(lt+rel_pos_arr.shape[0]), lt:(lt+rel_pos_arr.shape[1])] = rel_pos_arr seq_pos_tags.extend(pos_tags) seq_stems.extend(stems) if args.use_afsets: seq_afsets.extend(afsets) seq_affixes.extend(affixes) seq_tokens_lengths.extend(tokens_lengths) seq_predicted_stems.extend(predicted_stems) if args.use_afsets: seq_predicted_afsets.extend(predicted_afsets) if args.predict_affixes: seq_predicted_affixes.extend(predicted_affixes) seq_predicted_tokens_affixes_lengths.extend(predicted_tokens_affixes_lengths) if(lcount >= num_lines): data_item = (max_seq_len, seq_rel_pos_arr, seq_pos_tags, seq_stems, seq_afsets, seq_affixes, seq_tokens_lengths, seq_predicted_stems, seq_predicted_afsets, seq_predicted_affixes, seq_predicted_tokens_idx, seq_predicted_tokens_affixes_idx, seq_predicted_tokens_affixes_lengths) itemized_data.append(data_item) itemized_parsed_tokens.append(seq_parsed_tokens) return itemized_data, itemized_parsed_tokens, start_line def morpho_seq_collate_wrapper(batch_items): batch_pos_tags = [] batch_stems = [] batch_afsets = [] batch_affixes = [] batch_tokens_lengths = [] batch_predicted_stems = [] batch_predicted_afsets = [] batch_predicted_affixes = [] batch_predicted_tokens_idx = [] batch_predicted_tokens_affixes_idx = [] batch_predicted_tokens_affixes_lengths = [] batch_input_sequence_lengths = [] max_sequence_len = batch_items[0][0] first_seq_rel_pos_arr = batch_items[0][1] batch_rel_pos_arr = np.zeros((len(batch_items), max_sequence_len, max_sequence_len)).astype(int) if (first_seq_rel_pos_arr is not None) else None for bidx,data_item in enumerate(batch_items): (max_seq_len, seq_rel_pos_arr, seq_pos_tags, seq_stems, seq_afsets, seq_affixes, seq_tokens_lengths, seq_predicted_stems, seq_predicted_afsets, seq_predicted_affixes, seq_predicted_tokens_idx, seq_predicted_tokens_affixes_idx, seq_predicted_tokens_affixes_lengths) = data_item if batch_rel_pos_arr is not None: batch_rel_pos_arr[bidx,:,:] = seq_rel_pos_arr # Need to offset from total number of predicted token indices if seq_predicted_tokens_affixes_idx is not None: batch_predicted_tokens_affixes_idx.extend([(len(batch_predicted_tokens_idx) + t) for t in seq_predicted_tokens_affixes_idx]) batch_predicted_tokens_idx.extend([(t, len(batch_input_sequence_lengths)) for t in seq_predicted_tokens_idx]) batch_pos_tags.extend(seq_pos_tags) batch_stems.extend(seq_stems) if seq_afsets is not None: batch_afsets.extend(seq_afsets) batch_affixes.extend(seq_affixes) batch_tokens_lengths.extend(seq_tokens_lengths) batch_predicted_stems.extend(seq_predicted_stems) if seq_predicted_afsets is not None: batch_predicted_afsets.extend(seq_predicted_afsets) if seq_predicted_affixes is not None: batch_predicted_affixes.extend(seq_predicted_affixes) batch_predicted_tokens_affixes_lengths.extend(seq_predicted_tokens_affixes_lengths) batch_input_sequence_lengths.append(len(seq_tokens_lengths)) data_item = (batch_input_sequence_lengths, batch_rel_pos_arr, batch_pos_tags, batch_stems, batch_afsets, batch_affixes, batch_tokens_lengths, batch_predicted_stems, batch_predicted_afsets, batch_predicted_affixes, batch_predicted_tokens_idx, batch_predicted_tokens_affixes_idx, batch_predicted_tokens_affixes_lengths) return data_item class KBCorpusDataset(Dataset): def __init__(self,args, kb_vocab : KBVocab, affix_set_vocab : AffixSetVocab, bpe_encoder: yttm.BPE, rel_pos_dict, rel_pos_dmax, parsed_corpus_lines, doc_ends, is_corpus_parsed, start_line, max_batch_items, max_seq_len = 512, rank = 0): self.max_seq_len = max_seq_len self.start_line = start_line self.max_batch_items = max_batch_items if (rank==0): with progressbar.ProgressBar(max_value=max_batch_items, redirect_stdout=True) as bar: self.itemized_data, self.itemized_parsed_tokens, self.start_line = gather_replicated_itemized_data(args, parsed_corpus_lines, doc_ends, is_corpus_parsed, self.max_seq_len, self.start_line, self.max_batch_items, kb_vocab, affix_set_vocab, bpe_encoder, rel_pos_dict, rel_pos_dmax, rank = rank, bar=bar, shuffle = True) else: self.itemized_data, self.itemized_parsed_tokens, self.start_line = gather_replicated_itemized_data(args, parsed_corpus_lines, doc_ends, is_corpus_parsed, self.max_seq_len, self.start_line, self.max_batch_items, kb_vocab, affix_set_vocab, bpe_encoder, rel_pos_dict, rel_pos_dmax, rank=rank, bar=None, shuffle = True) def __len__(self): return len(self.itemized_data) def __getitem__(self, idx): return self.itemized_data[idx] from itertools import accumulate from morpho_model import KinyaBERT def morpho_model_forward(args, data_item, model : KinyaBERT, device, tot_num_affixes): (batch_input_sequence_lengths, batch_rel_pos_arr, batch_pos_tags, batch_stems, batch_afsets, batch_affixes, batch_tokens_lengths, batch_predicted_stems, batch_predicted_afsets, batch_predicted_affixes, batch_predicted_tokens_idx, batch_predicted_tokens_affixes_idx, batch_predicted_tokens_affixes_lengths) = data_item tokens_lengths = batch_tokens_lengths # torch.tensor(batch_tokens_lengths).to(device) input_sequence_lengths = batch_input_sequence_lengths # torch.tensor(batch_input_sequence_lengths).to(device) pos_tags = torch.tensor(batch_pos_tags).to(device) stems = torch.tensor(batch_stems).to(device) afsets = torch.tensor(batch_afsets).to(device) if args.use_afsets else None affixes = torch.tensor(batch_affixes).to(device) predicted_tokens_idx = torch.tensor([s * max(batch_input_sequence_lengths) + t for t, s in batch_predicted_tokens_idx]).to(device) predicted_tokens_affixes_idx = torch.tensor(batch_predicted_tokens_affixes_idx).to(device) if args.predict_affixes else None predicted_affixes_prob = None if args.predict_affixes: pred_affixes_list = [batch_predicted_affixes[x - y: x] for x, y in zip(accumulate(batch_predicted_tokens_affixes_lengths), batch_predicted_tokens_affixes_lengths)] afx_prob = torch.zeros(len(pred_affixes_list), tot_num_affixes) for i,lst in enumerate(pred_affixes_list): assert (len(lst) > 0) afx_prob[i,lst] = 1 / len(lst) predicted_affixes_prob = afx_prob.to(device) predicted_stems = torch.tensor(batch_predicted_stems).to(device) predicted_afsets = torch.tensor(batch_predicted_afsets).to(device) if args.use_afsets else None rel_pos_arr = torch.from_numpy(batch_rel_pos_arr).to(device) if (batch_rel_pos_arr is not None) else None return model(args, rel_pos_arr, tokens_lengths, input_sequence_lengths, pos_tags, stems, afsets, affixes, predicted_tokens_idx, predicted_tokens_affixes_idx, predicted_stems, predicted_afsets, predicted_affixes_prob) def morpho_model_seq_predict(args, data_item, model : KinyaBERT, device, max_predict_affixes, proposed_stem_ids=None): (max_seq_len, seq_rel_pos_arr, seq_pos_tags, seq_stems, seq_afsets, seq_affixes, seq_tokens_lengths, seq_predicted_stems, seq_predicted_afsets, seq_predicted_affixes, seq_predicted_tokens_idx, seq_predicted_tokens_affixes_idx, seq_predicted_tokens_affixes_lengths) = data_item tokens_lengths = seq_tokens_lengths input_sequence_lengths = [len(seq_tokens_lengths)] pos_tags = torch.tensor(seq_pos_tags).to(device) stems = torch.tensor(seq_stems).to(device) afsets = torch.tensor(seq_afsets).to(device) if args.use_afsets else None affixes = torch.tensor(seq_affixes).to(device) seq_predicted_token_idx = torch.tensor(seq_predicted_tokens_idx).to(device) rel_pos_arr = torch.from_numpy(seq_rel_pos_arr).unsqueeze(0).to(device) if (seq_rel_pos_arr is not None) else None return model.predict(args, rel_pos_arr, tokens_lengths, input_sequence_lengths, pos_tags, stems, afsets, affixes, seq_predicted_token_idx, max_predict_affixes, proposed_stem_ids=proposed_stem_ids)