Spaces:
Runtime error
Runtime error
Apoorv Saxena
commited on
Commit
·
e372099
1
Parent(s):
0c954bd
Update app.py
Browse files
app.py
CHANGED
|
@@ -50,22 +50,29 @@ def greedyPredict(input, model, tokenizer):
|
|
| 50 |
def predict_tail(entity, relation):
|
| 51 |
global model, tokenizer
|
| 52 |
input = entity + "| " + relation
|
| 53 |
-
out = topkSample(input, model, tokenizer, num_samples=
|
| 54 |
out_dict = {}
|
| 55 |
for k, v in out:
|
| 56 |
out_dict[k] = np.exp(v).item()
|
| 57 |
return out_dict
|
| 58 |
|
| 59 |
|
| 60 |
-
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
|
| 61 |
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-base-wikikg90mv2")
|
| 62 |
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
rel_input = gr.inputs.Textbox(lines=1, default="followed by")
|
| 67 |
output = gr.outputs.Label()
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
iface.launch()
|
|
|
|
| 50 |
def predict_tail(entity, relation):
|
| 51 |
global model, tokenizer
|
| 52 |
input = entity + "| " + relation
|
| 53 |
+
out = topkSample(input, model, tokenizer, num_samples=25)
|
| 54 |
out_dict = {}
|
| 55 |
for k, v in out:
|
| 56 |
out_dict[k] = np.exp(v).item()
|
| 57 |
return out_dict
|
| 58 |
|
| 59 |
|
| 60 |
+
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-base-wikikg90mv2")
|
| 61 |
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-base-wikikg90mv2")
|
| 62 |
|
| 63 |
|
| 64 |
+
ent_input = gr.inputs.Textbox(lines=1, default="Apoorv Umang Saxena")
|
| 65 |
+
rel_input = gr.inputs.Textbox(lines=1, default="country")
|
|
|
|
| 66 |
output = gr.outputs.Label()
|
| 67 |
|
| 68 |
+
examples = [
|
| 69 |
+
['Adrian Kochsiek', 'gender'],
|
| 70 |
+
['Apoorv Umang Saxena', 'family name'],
|
| 71 |
+
['World War II', 'followed by'],
|
| 72 |
+
['Apoorv Umang Saxena', 'country']
|
| 73 |
+
]
|
| 74 |
+
iface = gr.Interface(fn=predict_tail,
|
| 75 |
+
inputs=[ent_input, rel_input],
|
| 76 |
+
outputs=output,
|
| 77 |
+
examples=examples,)
|
| 78 |
iface.launch()
|