import gradio as gr
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load model and tokenizer once
model_name = "alusci/distilbert-smsafe"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, output_attentions=True)
model.eval()

# Main function
def classify_and_plot_attention(text):
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")

    # Forward pass with attention
    with torch.no_grad():
        outputs = model(**inputs)

    # Get prediction
    logits = outputs.logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    pred_idx = torch.argmax(probs).item()
    pred_label = model.config.id2label[pred_idx]
    pred_score = round(probs[0, pred_idx].item(), 4)

    # Extract attention across all layers and heads
    all_attn = torch.stack(outputs.attentions)  # (layers, batch, heads, seq_len, seq_len)
    mean_attn = all_attn.mean(dim=(0, 2))[0].numpy()  # average over layers & heads

    # Token filtering (remove CLS/SEP)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    real_token_idxs = [i for i, tok in enumerate(tokens) if tok not in ("[CLS]", "[SEP]")]
    real_tokens = [tokens[i] for i in real_token_idxs]
    trimmed_attn = mean_attn[np.ix_(real_token_idxs, real_token_idxs)]

    # Normalize
    norm_attn = (trimmed_attn - trimmed_attn.min()) / (trimmed_attn.max() - trimmed_attn.min())

    # Plot
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(norm_attn, xticklabels=real_tokens, yticklabels=real_tokens,
                cmap="viridis", square=True, ax=ax, cbar=True)
    ax.set_title("Normalized Attention Map")
    ax.set_xlabel("Input Tokens")
    ax.set_ylabel("Output Tokens")
    plt.xticks(rotation=45)
    plt.tight_layout()

    return f"Prediction: {pred_label} (Score: {pred_score})", fig

# Gradio UI
demo = gr.Interface(
    fn=classify_and_plot_attention,
    inputs=gr.Textbox(lines=3, placeholder="Paste your SMS OTP message here..."),
    outputs=["text", "plot"],
    title="SMS OTP Spam Classifier + Attention Visualizer",
    description="Enter an SMS OTP message to classify it and view the attention matrix.",
    allow_flagging="never"
)
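
# ----------------------------------------------------------------------------
# Optional local sanity check (illustrative sketch, not part of the original
# app). The sample message and output filename are hypothetical; call
# run_sanity_check() manually from a Python shell to exercise the classifier
# and save its attention heatmap without going through the Gradio UI.
# ----------------------------------------------------------------------------
def run_sanity_check(sample="Your one-time passcode is 482913. Do not share it with anyone."):
    label_text, attn_fig = classify_and_plot_attention(sample)
    print(label_text)                       # e.g. "Prediction: <label> (Score: 0.97)"
    attn_fig.savefig("attention_map.png")   # persist the heatmap for offline inspection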

if __name__ == "__main__":
    demo.launch()