alusci committed
Commit a00e3e2 · 1 Parent(s): 55ec307

Add gradio app and requirements
Files changed (2)
  1. app.py +66 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,66 @@
+import gradio as gr
+import torch
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+# Load model and tokenizer once
+model_name = "alusci/distilbert-smsafe"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
+model.eval()
+
+# Main function
+def classify_and_plot_attention(text):
+    # Tokenize input
+    inputs = tokenizer(text, return_tensors="pt")
+
+    # Forward pass with attention
+    with torch.no_grad():
+        outputs = model(**inputs)
+
+    # Get prediction
+    logits = outputs.logits
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    pred_idx = torch.argmax(probs).item()
+    pred_label = model.config.id2label[pred_idx]
+    pred_score = round(probs[0, pred_idx].item(), 4)
+
+    # Extract attention across all layers and heads
+    all_attn = torch.stack(outputs.attentions)  # (layers, batch, heads, seq_len, seq_len)
+    mean_attn = all_attn.mean(dim=(0, 2))[0].numpy()  # average over layers & heads
+
+    # Token filtering (remove CLS/SEP)
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+    real_token_idxs = [i for i, tok in enumerate(tokens) if tok not in ("[CLS]", "[SEP]")]
+    real_tokens = [tokens[i] for i in real_token_idxs]
+    trimmed_attn = mean_attn[np.ix_(real_token_idxs, real_token_idxs)]
+
+    # Normalize
+    norm_attn = (trimmed_attn - trimmed_attn.min()) / (trimmed_attn.max() - trimmed_attn.min())
+
+    # Plot
+    fig, ax = plt.subplots(figsize=(8, 6))
+    sns.heatmap(norm_attn, xticklabels=real_tokens, yticklabels=real_tokens,
+                cmap="viridis", square=True, ax=ax, cbar=True)
+    ax.set_title("Normalized Attention Map")
+    ax.set_xlabel("Input Tokens")
+    ax.set_ylabel("Output Tokens")
+    plt.xticks(rotation=45)
+    plt.tight_layout()
+
+    return f"Prediction: {pred_label} (Score: {pred_score})", fig
+
+# Gradio UI
+demo = gr.Interface(
+    fn=classify_and_plot_attention,
+    inputs=gr.Textbox(lines=3, placeholder="Paste your SMS OTP message here..."),
+    outputs=["text", "plot"],
+    title="SMS OTP Spam Classifier + Attention Visualizer",
+    description="Enter an SMS OTP message to classify it and view the attention matrix.",
+    allow_flagging="never"
+)
+
+if __name__ == "__main__":
+    demo.launch()
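
A quick local smoke test of the added function (a minimal sketch: the sample message and output file name are illustrative assumptions, and importing app triggers the model download from the Hub):

    # Assumes app.py is on the import path; the model loads at import time.
    from app import classify_and_plot_attention

    label, fig = classify_and_plot_attention("Your OTP code is 123456. Do not share it.")
    print(label)                      # e.g. "Prediction: <label> (Score: <score>)"
    fig.savefig("attention_map.png")  # illustrative output path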
requirements.txt ADDED
@@ -0,0 +1,6 @@
+transformers
+torch
+matplotlib
+seaborn
+gradio
+numpy
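
To run locally (assuming a working Python environment with both files above at the repo root):

    pip install -r requirements.txt
    python app.py   # serves the Gradio UI, by default at http://localhost:7860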