| import torch |
| from transformers import T5Tokenizer, T5ForConditionalGeneration |
|
|
def pretty_print(text, prompt=True):
    """Print *text* with each logical chunk separated by two blank lines.

    With ``prompt=True`` the text is treated as comma-separated sections whose
    premises are joined by " and "; the connective "and" is printed on its own
    chunk between premises. With ``prompt=False`` the text is split on the bare
    substring "and" (no surrounding spaces) and each piece printed as a chunk.
    Returns ``None`` (the result of ``print``).
    """
    chunks = []
    if prompt:
        for section in text.split(', '):
            premises = section.split(" and ")
            # Interleave the connective between consecutive premises.
            for premise in premises[:-1]:
                chunks.append(premise)
                chunks.append("and")
            chunks.append(premises[-1])
    else:
        # NOTE(review): splits on "and" anywhere, even inside words — presumably
        # fine for equation strings; confirm against callers.
        chunks.extend(text.split("and"))
    return print("\n\n\n".join(chunks))
|
|
|
|
def load_model(model_id):
    """Load a T5 tokenizer/model pair for *model_id*.

    The model is moved to GPU when CUDA is available, CPU otherwise.
    Returns ``(tokenizer, model)``.
    """
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    tok = T5Tokenizer.from_pretrained(model_id)
    net = T5ForConditionalGeneration.from_pretrained(model_id)
    return tok, net.to(target)
|
|
|
|
def inference(prompt, tokenizer, model):
    """Generate text from *prompt* and normalize LaTeX-style backslash tokens.

    Encodes the prompt (truncated to 512 tokens), runs ``model.generate``,
    decodes the first sequence, then post-processes the result so that any
    bare occurrence of a symbol that also appears in backslash form (e.g.
    "alpha" alongside "\\alpha") is rewritten to the backslash form.

    Returns the post-processed generated string.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_ids = tokenizer.encode(prompt, return_tensors='pt', max_length=512, truncation=True).to(device)
    output = model.generate(input_ids=input_ids, max_length=512, early_stopping=True)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # The model sometimes emits "\ cmd" instead of "\cmd"; collapse that first.
    derivation = generated_text.replace("\\ ", "\\")
    partial_symbols = derivation.split(" ")
    # Map each backslash token's bare spelling to its backslash form.
    # A dict gives a deterministic O(1) lookup per token — the original nested
    # loop over a set was O(n*m) and order-dependent whenever two backslash
    # commands stripped to the same bare name.
    bare_to_cmd = {tok.replace("\\", ""): tok for tok in partial_symbols if "\\" in tok}
    restored = [bare_to_cmd.get(tok, tok) for tok in partial_symbols]
    return " ".join(restored)
|
|