Spaces:
Runtime error
| import streamlit as st | |
| import torch | |
| from transformers import BertTokenizer, BertModel | |
| from huggingface_hub import hf_hub_url, cached_download | |
def get_cls_layer(repo_id="furrutiav/beto_coherence", revision=None):
    """Download ``cls_layer.torch`` from a Hub repo and load it onto the CPU.

    Args:
        repo_id: Hugging Face Hub repository that hosts the file.
        revision: optional branch, tag or commit hash to download from;
            ``None`` uses the repository's default branch (the previous
            behavior).

    Returns:
        Whatever object was serialized into ``cls_layer.torch``. Callers in
        this file invoke it as ``cls_layer(cls_rep)``, so presumably it is a
        pickled ``torch.nn.Module`` -- TODO confirm.
    """
    # `cached_download` was deprecated and removed from newer huggingface_hub
    # releases; `hf_hub_download` is its replacement and returns the path of
    # the locally cached file.
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename="cls_layer.torch",
        revision=revision,
    )
    # NOTE(security): weights_only=False unpickles arbitrary objects, which
    # can execute code embedded in the file. It is required here because the
    # file holds a whole module rather than a plain state_dict (torch >= 2.6
    # defaults to weights_only=True and would refuse it). Only load from a
    # repository you trust; prefer pinning `revision`.
    return torch.load(
        local_path,
        map_location=torch.device("cpu"),
        weights_only=False,
    )
# Hub repository and pinned commit for the fine-tuned encoder and tokenizer.
_BETO_REPO_ID = "furrutiav/beto_coherence"
_BETO_REVISION = "df96f50cfb1e3f7923912a25b1c3a865116fae4a"

# Classification head applied to the encoder's [CLS] representation.
cls_layer = get_cls_layer()

# Encoder and tokenizer, both pinned to the same commit; the tokenizer keeps
# the original casing (do_lower_case=False).
beto_model = BertModel.from_pretrained(_BETO_REPO_ID, revision=_BETO_REVISION)
beto_tokenizer = BertTokenizer.from_pretrained(
    _BETO_REPO_ID,
    revision=_BETO_REVISION,
    do_lower_case=False,
)

# Put the encoder in evaluation mode. `eval()` returns the module itself; the
# module-level alias `e` is kept so existing references keep working.
e = beto_model.eval()
def _normalize_text(text):
    """Collapse every whitespace run in ``str(text)`` to one space; empty -> "nan"."""
    # str.split() with no argument already splits on newlines, tabs, etc.
    collapsed = " ".join(str(text).split())
    return collapsed if collapsed else "nan"


def _fit_to_length(tokens, maxlen):
    """Right-pad ``tokens`` with '[PAD]' to ``maxlen``, or truncate them.

    When truncating, the result keeps the first ``maxlen - 1`` tokens and
    ends with '[SEP]', so a segment never loses its separator.
    """
    if len(tokens) < maxlen:
        return tokens + ["[PAD]"] * (maxlen - len(tokens))
    return tokens[: maxlen - 1] + ["[SEP]"]


def preproccesing(Q, A, maxlen=60):
    """Encode a (question, answer) pair into fixed-length input ids and a mask.

    Each side is padded or truncated independently to ``maxlen`` tokens. The
    question segment is ``[CLS] Q [SEP]`` and the answer segment is
    ``A [SEP]``. The two segments are concatenated, so for ``maxlen >= 1``
    the output has ``2 * maxlen`` positions. No token_type_ids are produced.

    Args:
        Q: question; any object, converted with ``str``. Whitespace is
            collapsed and an empty result becomes the literal "nan".
        A: answer; treated the same way as ``Q``.
        maxlen: length of each segment in tokens.

    Returns:
        ``(tokens_ids_tensor, attn_mask)``: 1-D integer tensors of equal
        length. The mask is 1 for real tokens and 0 at '[PAD]' positions.
    """
    q_tokens = ["[CLS]"] + beto_tokenizer.tokenize(_normalize_text(Q)) + ["[SEP]"]
    a_tokens = beto_tokenizer.tokenize(_normalize_text(A)) + ["[SEP]"]
    tokens = _fit_to_length(q_tokens, maxlen) + _fit_to_length(a_tokens, maxlen)

    tokens_ids_tensor = torch.tensor(beto_tokenizer.convert_tokens_to_ids(tokens))
    # Mask padding by the tokenizer's own '[PAD]' id instead of a hard-coded 1
    # (presumably BETO's pad id, but not guaranteed for other vocabularies).
    pad_id = beto_tokenizer.convert_tokens_to_ids("[PAD]")
    attn_mask = (tokens_ids_tensor != pad_id).long()
    return tokens_ids_tensor, attn_mask
def C1Classifier(Q, A, is_probs=True):
    """Score a (question, answer) pair with ``beto_model`` and ``cls_layer``.

    The pair is encoded by ``preproccesing`` and run through the encoder. The
    hidden state at position 0 (the '[CLS]' token) is fed to ``cls_layer``,
    and the resulting logits are passed through a sigmoid.

    Args:
        Q: question text (anything ``str()`` accepts).
        A: answer text.
        is_probs: if truthy, return the sigmoid probabilities; otherwise
            return the index of the largest one.

    Returns:
        Row 0 of the probabilities as a NumPy array when ``is_probs`` is
        truthy, else the arg-max index along dim 1 as a NumPy integer. This
        assumes ``cls_layer`` returns a 2-D (batch, outputs) tensor -- TODO
        confirm.
    """
    tokens_ids_tensor, attn_mask = preproccesing(Q, A)
    # Pure inference: disabling autograd avoids building a graph on every
    # call, which saves memory and time without changing the outputs.
    with torch.no_grad():
        cont_reps = beto_model(
            tokens_ids_tensor.unsqueeze(0),  # add a batch dimension of 1
            attention_mask=attn_mask.unsqueeze(0),
        )
        cls_rep = cont_reps.last_hidden_state[:, 0]  # '[CLS]' position
        logits = cls_layer(cls_rep)
        probs = torch.sigmoid(logits)

    if is_probs:
        return probs.detach().numpy()[0]
    # Hard label: index of the highest probability.
    return probs.argmax(1).numpy()[0]