"""Build a character-level bigram model from a list of words.

Reads one word per line from ``data/names.txt``, counts how often each
symbol is immediately followed by each other symbol (``TOKEN`` marks both
the start and the end of every word), turns each row of the count matrix
into a probability distribution, and pickles the list
``[P, char_to_int, int_to_char]`` to ``model/bigrams.pkl``.
"""

import pickle

import torch

# Boundary marker placed before the first and after the last character of
# every word.
TOKEN = '.'

# One word per line. The context manager closes the handle as soon as the
# data is read, and the explicit encoding keeps the resulting vocabulary
# independent of the platform's default text encoding.
with open('data/names.txt', 'r', encoding='utf-8') as file:
    words = file.read().splitlines()

# Every distinct character seen in the data plus the boundary marker, in
# sorted order so that the index assignment below is deterministic.
vocab = sorted(set(''.join(words)) | {TOKEN})

# N[i, j] holds the number of times symbol j directly follows symbol i.
n = len(vocab)
N = torch.zeros((n, n), dtype=torch.int32)

# Two-way lookup between symbols and their row/column index in N.
char_to_int = {char: i for i, char in enumerate(vocab)}
int_to_char = {value: key for key, value in char_to_int.items()}

for word in words:
    # Wrap the word in boundary markers so that "starts with" and
    # "ends with" transitions are counted as well.
    chars = [TOKEN] + list(word) + [TOKEN]
    # Pair every symbol with its successor.
    for ch1, ch2 in zip(chars, chars[1:]):
        ix1 = char_to_int[ch1]
        ix2 = char_to_int[ch2]
        N[ix1, ix2] += 1

# Normalise each row of counts so that it sums to 1. keepdim=True keeps the
# row totals as an (n, 1) column, so the division broadcasts across each row.
# NOTE: with an empty word list the single TOKEN row has a total of 0 and
# this division yields NaN.
P = N.float()
P /= P.sum(1, keepdim=True)

# Persist the probability matrix together with both lookup tables.
with open('model/bigrams.pkl', 'wb') as file:
    pickle.dump([P, char_to_int, int_to_char], file)