Spaces:
Sleeping
Sleeping
File size: 32,468 Bytes
dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 816d43f dd5a03c 816d43f dd5a03c 816d43f be5c319 816d43f 2b445e9 816d43f dd5a03c be5c319 9bf5c2d be5c319 dd5a03c 816d43f be5c319 816d43f be5c319 9bf5c2d be5c319 9bf5c2d be5c319 9bf5c2d be5c319 9bf5c2d 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c 816d43f be5c319 dd5a03c 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c 816d43f be5c319 dd5a03c be5c319 816d43f dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 9bf5c2d dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 816d43f 9bf5c2d be5c319 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c 816d43f be5c319 dd5a03c 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 dd5a03c 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 816d43f be5c319 816d43f 9bf5c2d be5c319 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c 816d43f be5c319 dd5a03c 816d43f dd5a03c 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 816d43f be5c319 816d43f 9bf5c2d be5c319 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 
dd5a03c 816d43f be5c319 dd5a03c 816d43f dd5a03c 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 816d43f be5c319 816d43f 9bf5c2d be5c319 816d43f be5c319 dd5a03c be5c319 816d43f be5c319 dd5a03c be5c319 dd5a03c be5c319 dd5a03c 816d43f dd5a03c 816d43f dd5a03c 38107f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 |
# app.py
import os
import sys
import time
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
# Add src to path
sys.path.append(os.path.join(os.path.dirname(__file__), "src"))
from auditor import create_auditors
from explainer import explain_attention, explain_gradcam, explain_gradient_shap
from model_loader import SUPPORTED_MODELS, load_model_and_processor
from predictor import create_prediction_plot, predict_image
from utils import get_top_predictions_dict, preprocess_image
# Lazily-populated cache shared by the handlers below: the loaded model, its
# processor, the SUPPORTED_MODELS value it was loaded from, and the auditors
# built for it. All four stay None until a model is loaded.
model = processor = current_model_name = auditors = None
def load_selected_model(model_name):
    """Load the selected model and cache it globally.

    The model, its processor, ``model_name`` and the auditors created for the
    model are cached in module-level globals, so repeated calls with the same
    name do no work. The globals are only replaced once BOTH the model and its
    auditors were created, so a failure never leaves a half-initialised cache
    (e.g. a new model paired with missing or stale auditors).

    Args:
        model_name: Identifier handed to ``load_model_and_processor`` (the
            values of ``SUPPORTED_MODELS``).

    Returns:
        A status string: ``"✅ Model loaded: <name>"`` on success, or
        ``"β Error loading model: <reason>"`` if loading failed.
    """
    global model, processor, current_model_name, auditors

    # Cache hit: everything for this model is already in place.
    if model is not None and auditors is not None and current_model_name == model_name:
        return f"✅ Model loaded: {model_name}"

    try:
        print(f"Loading model: {model_name}")
        # Build into locals first; the globals are only touched on success.
        new_model, new_processor = load_model_and_processor(model_name)
        new_auditors = create_auditors(new_model, new_processor)
    except Exception as e:
        # NOTE(review): "β" appears to be a mis-encoded emoji in this file's
        # error strings; it is kept unchanged for compatibility with any
        # caller that matches on it.
        return f"β Error loading model: {str(e)}"

    model, processor = new_model, new_processor
    current_model_name = model_name
    auditors = new_auditors
    print("✅ Model and auditors loaded successfully!")
    return f"✅ Model loaded: {model_name}"
def analyze_image_basic(image, model_choice, xai_method, layer_index, head_index):
    """
    Basic explainability analysis - the core function for Tab 1.

    Runs the selected model on the uploaded image, plots its top predictions
    and renders one explanation (attention map, GradCAM or GradientSHAP).

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        model_choice: Display name of the model; a key of ``SUPPORTED_MODELS``.
        xai_method: "Attention Visualization", "GradCAM" or "GradientSHAP".
        layer_index: Transformer layer to visualise (attention only).
        head_index: Attention head to visualise (attention only).

    Returns:
        ``(processed_image, prediction_figure, explanation_figure, status)`` -
        one value per Gradio output wired to this handler. On failure the
        first three are ``None`` and ``status`` describes the problem.
    """
    # Check the cheap precondition first so a model is never loaded just to
    # report that no image was supplied.
    if image is None:
        return None, None, None, "β οΈ Please upload an image first."

    try:
        model_name = SUPPORTED_MODELS[model_choice]
        model_status = load_selected_model(model_name)
        # Detect a failed load from the cached state instead of searching the
        # status text for a marker character, which breaks as soon as the
        # success and error messages share that character.
        if model is None or current_model_name != model_name:
            return None, None, None, model_status

        processed_image = preprocess_image(image)

        # Get predictions
        probs, _indices, labels = predict_image(processed_image, model, processor)
        pred_fig = create_prediction_plot(probs, labels)

        # Generate explanation based on selected method
        explanation_fig = None
        if xai_method == "Attention Visualization":
            # Slider values may arrive as floats; cast them for use as indices.
            explanation_fig = explain_attention(
                model,
                processor,
                processed_image,
                layer_index=int(layer_index),
                head_index=int(head_index),
            )
        elif xai_method == "GradCAM":
            # explain_gradcam returns (figure, image); only the figure is shown.
            explanation_fig, _ = explain_gradcam(model, processor, processed_image)
        elif xai_method == "GradientSHAP":
            explanation_fig = explain_gradient_shap(
                model, processor, processed_image, n_samples=3
            )

        return (
            processed_image,
            pred_fig,
            explanation_fig,
            f"✅ Analysis complete! Top prediction: {labels[0]} ({probs[0]:.2%})",
        )
    except Exception as e:
        error_msg = f"β Analysis failed: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg
def analyze_counterfactual(image, model_choice, patch_size, perturbation_type):
    """
    Counterfactual analysis for Tab 2.

    Perturbs the image patch by patch through the "counterfactual" auditor and
    summarises how the model's prediction reacts.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        model_choice: Display name of the model; a key of ``SUPPORTED_MODELS``.
        patch_size: Size in pixels of each perturbation patch.
        perturbation_type: "blur", "blackout", "gray" or "noise".

    Returns:
        ``(figure, status)`` - exactly two values, matching the two Gradio
        outputs wired to this handler (every path, including the early
        returns, must honour this). ``figure`` is ``None`` on failure.
    """
    if image is None:
        return None, "β οΈ Please upload an image first."

    try:
        model_name = SUPPORTED_MODELS[model_choice]
        model_status = load_selected_model(model_name)
        # Judge success from the cached state rather than the status text.
        if model is None or auditors is None or current_model_name != model_name:
            return None, model_status

        processed_image = preprocess_image(image)

        # Perform counterfactual analysis (slider value cast to a pixel count)
        results = auditors["counterfactual"].patch_perturbation_analysis(
            processed_image, patch_size=int(patch_size), perturbation_type=perturbation_type
        )

        summary = (
            f"π Counterfactual Analysis Complete!\n"
            f"β’ Avg confidence change: {results['avg_confidence_change']:.4f}\n"
            f"β’ Prediction flip rate: {results['prediction_flip_rate']:.2%}\n"
            f"β’ Most sensitive patch: {results['most_sensitive_patch']}"
        )
        return results["figure"], summary
    except Exception as e:
        error_msg = f"β Counterfactual analysis failed: {str(e)}"
        print(error_msg)
        return None, error_msg
def analyze_calibration(image, model_choice, n_bins):
    """
    Confidence calibration analysis for Tab 3.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        model_choice: Display name of the model; a key of ``SUPPORTED_MODELS``.
        n_bins: Number of confidence bins for the calibration analysis.

    Returns:
        ``(figure, status)`` - exactly two values, matching the two Gradio
        outputs wired to this handler. ``figure`` is ``None`` on failure.
    """
    if image is None:
        return None, "β οΈ Please upload an image first."

    try:
        model_name = SUPPORTED_MODELS[model_choice]
        model_status = load_selected_model(model_name)
        # Judge success from the cached state rather than the status text.
        if model is None or auditors is None or current_model_name != model_name:
            return None, model_status

        processed_image = preprocess_image(image)

        # Demo shortcut: with no validation set available, the uploaded image
        # is repeated to form a tiny "test set". A real calibration study
        # needs a proper, varied validation set.
        test_images = [processed_image] * 10

        results = auditors["calibration"].analyze_calibration(test_images, n_bins=int(n_bins))

        metrics = results["metrics"]
        summary = (
            f"π Calibration Analysis Complete!\n"
            f"β’ Mean confidence: {metrics['mean_confidence']:.3f}\n"
            f"β’ Overconfident rate: {metrics['overconfident_rate']:.2%}\n"
            f"β’ Underconfident rate: {metrics['underconfident_rate']:.2%}"
        )
        return results["figure"], summary
    except Exception as e:
        error_msg = f"β Calibration analysis failed: {str(e)}"
        print(error_msg)
        return None, error_msg
def _to_byte(value):
    """Round ``value`` to the nearest integer and clamp it to [0, 255]."""
    return max(0, min(255, round(value)))


def analyze_bias_detection(image, model_choice):
    """
    Bias detection analysis for Tab 4.

    Builds four demo subgroups from the uploaded image (original, brighter,
    darker, higher contrast) and compares model behaviour across them with
    the "bias" auditor.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        model_choice: Display name of the model; a key of ``SUPPORTED_MODELS``.

    Returns:
        ``(figure, status)`` - exactly two values, matching the two Gradio
        outputs wired to this handler. ``figure`` is ``None`` on failure.
    """
    if image is None:
        return None, "β οΈ Please upload an image first."

    try:
        model_name = SUPPORTED_MODELS[model_choice]
        model_status = load_selected_model(model_name)
        # Judge success from the cached state rather than the status text.
        if model is None or auditors is None or current_model_name != model_name:
            return None, model_status

        processed_image = preprocess_image(image)

        # Demo subgroups derived from the single uploaded image; a real audit
        # would use predefined subgroups from a dataset. Image.point returns a
        # new image, and every lookup value is rounded and clamped to the
        # 8-bit range (the contrast stretch would otherwise produce values
        # from -64 up to 318.5).
        subset_names = ["Original", "Brightness+", "Brightness-", "Contrast+"]
        subsets = [
            [processed_image],
            [processed_image.point(lambda p: _to_byte(p * 1.5))],
            [processed_image.point(lambda p: _to_byte(p * 0.7))],
            [processed_image.point(lambda p: _to_byte(128 + (p - 128) * 1.5))],
        ]

        results = auditors["bias"].analyze_subgroup_performance(subsets, subset_names)

        subgroup_metrics = results["subgroup_metrics"]
        lines = [
            f"β’ {name}: confidence={metrics['mean_confidence']:.3f}\n"
            for name, metrics in subgroup_metrics.items()
        ]
        summary = (
            f"βοΈ Bias Detection Complete!\nAnalyzed {len(subgroup_metrics)} subgroups:\n"
            + "".join(lines)
        )
        return results["figure"], summary
    except Exception as e:
        error_msg = f"β Bias detection failed: {str(e)}"
        print(error_msg)
        return None, error_msg
def create_demo_image():
    """Create a demo image for first-time users.

    Returns a new 224x224 RGB image with a (150, 100, 100) background, a
    green (100, 200, 100) square covering x, y in [50, 150) and a blue
    (100, 100, 200) square covering x, y in [160, 200).
    """
    img = Image.new("RGB", (224, 224), color=(150, 100, 100))
    # paste() with a colour fills the box (left, upper, right, lower) in a
    # single call instead of thousands of putpixel() calls; right/lower are
    # exclusive, matching the original range() bounds.
    img.paste((100, 200, 100), (50, 50, 150, 150))  # Green square
    img.paste((100, 100, 200), (160, 160, 200, 200))  # Blue square
    return img
# Minimal CSS passed to gr.Blocks(css=...) below: dark gradient page
# background, the `.main-header` card, primary-button, block and tab styling.
# Kept deliberately light so it does not interfere with Gradio's dropdowns.
custom_css = """
/* Basic styling without interfering with dropdowns */
.gradio-container {
background: linear-gradient(135deg, #0f1419 0%, #1a1f2e 50%, #0f1419 100%);
font-family: 'Inter', sans-serif;
}
/* Header styling */
.main-header {
background: rgba(99, 102, 241, 0.05);
border-radius: 20px;
padding: 2.5rem;
margin-bottom: 2rem;
}
/* Button styling */
button.primary {
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
border: none;
color: white;
font-weight: 600;
padding: 14px 32px;
border-radius: 12px;
}
button.primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 24px rgba(99, 102, 241, 0.6);
}
/* Block styling */
.block {
background: rgba(30, 41, 59, 0.4);
border-radius: 16px;
padding: 1.5rem;
border: 1px solid rgba(99, 102, 241, 0.15);
}
/* Tab styling */
.tab-nav button {
background: rgba(30, 41, 59, 0.5);
border: 1px solid rgba(99, 102, 241, 0.2);
border-radius: 12px;
padding: 14px 28px;
margin: 0 6px;
color: #94a3b8;
font-weight: 600;
}
.tab-nav button.selected {
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
color: white;
}
"""
# Create the Gradio interface.
# Layout is defined purely by the order and nesting of component creation
# inside the `with` contexts below, so statements here are order-sensitive.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="ViT Auditing Toolkit") as demo:
    # Main Header: static banner styled by the `.main-header` rule in custom_css.
    gr.HTML(
        """
<div class="main-header">
<h1 style="
font-size: 3rem;
font-weight: 800;
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 50%, #ec4899 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 0.5rem;
text-align: center;
">
π― ViT Auditing Toolkit
</h1>
<p style="
font-size: 1.25rem;
color: #94a3b8;
text-align: center;
font-weight: 500;
margin-bottom: 0;
">
Comprehensive Model Explainability and Validation Dashboard
</p>
</div>
"""
    )
    # About Section: static description with one card per analysis tab.
    gr.HTML(
        """
<div style="
background: rgba(30, 41, 59, 0.4);
border-radius: 16px;
padding: 2rem;
margin-bottom: 2rem;
border: 1px solid rgba(99, 102, 241, 0.15);
">
<h2 style="font-size: 1.75rem; font-weight: 700; color: #e0e7ff; margin-bottom: 1rem;">
βΉοΈ About This Toolkit
</h2>
<p style="color: #94a3b8; line-height: 1.8; font-size: 1.05rem; margin-bottom: 1.5rem;">
This interactive dashboard provides comprehensive auditing capabilities for Vision Transformer models,
enabling researchers and practitioners to understand, validate, and improve their AI models through
multiple explainability techniques.
</p>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(240px, 1fr)); gap: 1rem;">
<div style="background: rgba(99, 102, 241, 0.08); padding: 1.5rem; border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
<div style="font-size: 2rem; margin-bottom: 0.5rem;">π</div>
<strong style="color: #a5b4fc; font-size: 1.1rem;">Basic Explainability</strong>
<p style="margin-top: 0.5rem; font-size: 0.9rem; color: #94a3b8;">
Understand model predictions with attention maps, GradCAM, and SHAP visualizations
</p>
</div>
<div style="background: rgba(99, 102, 241, 0.08); padding: 1.5rem; border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
<div style="font-size: 2rem; margin-bottom: 0.5rem;">π</div>
<strong style="color: #c4b5fd; font-size: 1.1rem;">Counterfactual Analysis</strong>
<p style="margin-top: 0.5rem; font-size: 0.9rem; color: #94a3b8;">
Test prediction robustness by systematically perturbing image regions
</p>
</div>
<div style="background: rgba(99, 102, 241, 0.08); padding: 1.5rem; border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
<div style="font-size: 2rem; margin-bottom: 0.5rem;">π</div>
<strong style="color: #f9a8d4; font-size: 1.1rem;">Confidence Calibration</strong>
<p style="margin-top: 0.5rem; font-size: 0.9rem; color: #94a3b8;">
Evaluate whether model confidence scores accurately reflect prediction reliability
</p>
</div>
<div style="background: rgba(99, 102, 241, 0.08); padding: 1.5rem; border-radius: 12px; border: 1px solid rgba(99, 102, 241, 0.2);">
<div style="font-size: 2rem; margin-bottom: 0.5rem;">βοΈ</div>
<strong style="color: #93c5fd; font-size: 1.1rem;">Bias Detection</strong>
<p style="margin-top: 0.5rem; font-size: 0.9rem; color: #94a3b8;">
Identify performance variations across different demographic or data subgroups
</p>
</div>
</div>
</div>
"""
    )
    # Quick Start Guide: static four-step walkthrough plus a tip.
    gr.HTML(
        """
<div style="
background: rgba(99, 102, 241, 0.1);
border-radius: 16px;
padding: 2rem;
margin-bottom: 2rem;
border: 1px solid rgba(99, 102, 241, 0.25);
">
<h2 style="font-size: 1.5rem; font-weight: 700; color: #e0e7ff; margin-bottom: 1.5rem;">
π Quick Start Guide
</h2>
<div style="display: grid; gap: 1rem;">
<div style="display: flex; align-items: start; gap: 1rem;">
<div style="
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
border-radius: 50%;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
font-weight: 700;
color: white;
flex-shrink: 0;
">1</div>
<div>
<strong style="color: #c4b5fd; font-size: 1.05rem;">Select a Model</strong>
<p style="color: #94a3b8; margin-top: 0.25rem; line-height: 1.6;">
Choose a Vision Transformer model from the dropdown and click "Load Model" button
</p>
</div>
</div>
<div style="display: flex; align-items: start; gap: 1rem;">
<div style="
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
border-radius: 50%;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
font-weight: 700;
color: white;
flex-shrink: 0;
">2</div>
<div>
<strong style="color: #c4b5fd; font-size: 1.05rem;">Upload Your Image</strong>
<p style="color: #94a3b8; margin-top: 0.25rem; line-height: 1.6;">
Navigate to any tab and upload an image you want to analyze
</p>
</div>
</div>
<div style="display: flex; align-items: start; gap: 1rem;">
<div style="
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
border-radius: 50%;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
font-weight: 700;
color: white;
flex-shrink: 0;
">3</div>
<div>
<strong style="color: #c4b5fd; font-size: 1.05rem;">Choose Analysis Type</strong>
<p style="color: #94a3b8; margin-top: 0.25rem; line-height: 1.6;">
Select from 4 tabs: Basic Explainability, Counterfactual Analysis, Confidence Calibration, or Bias Detection
</p>
</div>
</div>
<div style="display: flex; align-items: start; gap: 1rem;">
<div style="
background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%);
border-radius: 50%;
width: 32px;
height: 32px;
display: flex;
align-items: center;
justify-content: center;
font-weight: 700;
color: white;
flex-shrink: 0;
">4</div>
<div>
<strong style="color: #c4b5fd; font-size: 1.05rem;">Run Analysis</strong>
<p style="color: #94a3b8; margin-top: 0.25rem; line-height: 1.6;">
Adjust settings if needed, then click the analysis button to see results and visualizations
</p>
</div>
</div>
</div>
<div style="
margin-top: 1.5rem;
padding: 1rem;
background: rgba(139, 92, 246, 0.1);
border-radius: 12px;
border-left: 4px solid #8b5cf6;
">
<p style="color: #c4b5fd; margin: 0; font-size: 0.95rem;">
π‘ <strong>Tip:</strong> Start with "Basic Explainability" to understand what your model sees,
then explore advanced auditing features in other tabs.
</p>
</div>
</div>
"""
    )
    # Model selection (shared across all tabs): every tab's click handler
    # below takes `model_choice` as an input, and its value is a key of
    # SUPPORTED_MODELS.
    with gr.Row():
        with gr.Column(scale=3):
            model_choice = gr.Dropdown(
                choices=list(SUPPORTED_MODELS.keys()),
                value="ViT-Base",
                label="π― Select Model",
                info="Choose which Vision Transformer model to use",
            )
        with gr.Column(scale=3):
            model_status = gr.Textbox(
                label="π‘ Model Status",
                interactive=False,
                placeholder="Select a model and click 'Load Model' to begin...",
            )
        with gr.Column(scale=2):
            load_btn = gr.Button("π Load Model", variant="primary", size="lg")
    # Explicit load: map the display name to its SUPPORTED_MODELS value and
    # show the returned status string. The analysis handlers also load the
    # model on demand, so pressing this button first is optional.
    load_btn.click(
        fn=lambda model: load_selected_model(SUPPORTED_MODELS[model]),
        inputs=[model_choice],
        outputs=[model_status],
    )
    # Tabbed interface: one tab per analysis, each wired to its own handler.
    with gr.Tabs():
        # Tab 1: Basic Explainability -> analyze_image_basic (4 outputs)
        with gr.TabItem("π Basic Explainability"):
            gr.Markdown(
                """
### Understanding Model Predictions
Visualize what the model "sees" and understand which features influence its decisions.
"""
            )
            with gr.Row():
                # Left column: inputs, settings, trigger button and status text.
                with gr.Column(scale=1):
                    image_input = gr.Image(
                        label="π Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350,
                    )
                    with gr.Accordion("βοΈ Explanation Settings", open=False):
                        xai_method = gr.Dropdown(
                            choices=["Attention Visualization", "GradCAM", "GradientSHAP"],
                            value="Attention Visualization",
                            label="π¬ Explanation Method",
                            info="Select the explainability technique to apply",
                        )
                        gr.Markdown("**Attention-specific Parameters:**")
                        # NOTE(review): the 0-11 slider ranges presumably assume a
                        # 12-layer / 12-head model such as ViT-Base; confirm they fit
                        # every entry in SUPPORTED_MODELS.
                        with gr.Row():
                            layer_index = gr.Slider(
                                minimum=0,
                                maximum=11,
                                value=6,
                                step=1,
                                label="Layer Index",
                                info="Which transformer layer to visualize (0-11)",
                            )
                        with gr.Row():
                            head_index = gr.Slider(
                                minimum=0,
                                maximum=11,
                                value=0,
                                step=1,
                                label="Head Index",
                                info="Which attention head to visualize (0-11)",
                            )
                    analyze_btn = gr.Button("π Analyze Image", variant="primary", size="lg")
                    status_output = gr.Textbox(
                        label="π Analysis Status",
                        interactive=False,
                        placeholder="Upload an image and click 'Analyze Image' to start...",
                        lines=4,
                        max_lines=6,
                    )
                # Right column: result displays.
                with gr.Column(scale=2):
                    with gr.Row():
                        original_display = gr.Image(
                            label="πΈ Processed Image", interactive=False, height=300
                        )
                        prediction_display = gr.Plot(label="π Top Predictions")
                    explanation_display = gr.Plot(label="π Explanation Visualization")
            # Connect the analyze button; the handler must return one value per
            # output component listed here (4).
            analyze_btn.click(
                fn=analyze_image_basic,
                inputs=[image_input, model_choice, xai_method, layer_index, head_index],
                outputs=[original_display, prediction_display, explanation_display, status_output],
            )
        # Tab 2: Counterfactual Analysis -> analyze_counterfactual (2 outputs)
        with gr.TabItem("π Counterfactual Analysis"):
            gr.Markdown(
                """
### Testing Model Robustness
Systematically perturb image regions to understand which areas are most critical for predictions.
"""
            )
            with gr.Row():
                with gr.Column(scale=1):
                    cf_image_input = gr.Image(
                        label="π Upload Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350,
                    )
                    with gr.Accordion("βοΈ Counterfactual Settings", open=True):
                        # step=16 restricts the patch size to 16, 32, 48 or 64 px.
                        patch_size = gr.Slider(
                            minimum=16,
                            maximum=64,
                            value=32,
                            step=16,
                            label="π² Patch Size",
                            info="Size of perturbation patches - 16, 32, 48, or 64 pixels",
                        )
                        perturbation_type = gr.Dropdown(
                            choices=["blur", "blackout", "gray", "noise"],
                            value="blur",
                            label="π¨ Perturbation Type",
                            info="How to modify image patches",
                        )
                        gr.Markdown(
                            """
**Perturbation Types:**
- **Blur**: Gaussian blur effect
- **Blackout**: Replace with black pixels
- **Gray**: Convert to grayscale
- **Noise**: Add random noise
"""
                        )
                    cf_analyze_btn = gr.Button(
                        "π Run Counterfactual Analysis", variant="primary", size="lg"
                    )
                    cf_status_output = gr.Textbox(
                        label="π Analysis Status",
                        interactive=False,
                        placeholder="Upload an image and click to start counterfactual analysis...",
                        lines=5,
                        max_lines=8,
                    )
                with gr.Column(scale=2):
                    cf_explanation_display = gr.Plot(label="π Counterfactual Analysis Results")
                    gr.Markdown(
                        """
**Understanding Results:**
- **Confidence Change**: How much the model's certainty shifts
- **Prediction Flip Rate**: Percentage of patches causing misclassification
- **Sensitive Regions**: Areas most critical to the model's decision
"""
                    )
            # Two outputs: the handler must return exactly (figure, status).
            cf_analyze_btn.click(
                fn=analyze_counterfactual,
                inputs=[cf_image_input, model_choice, patch_size, perturbation_type],
                outputs=[cf_explanation_display, cf_status_output],
            )
        # Tab 3: Confidence Calibration -> analyze_calibration (2 outputs)
        with gr.TabItem("π Confidence Calibration"):
            gr.Markdown(
                """
### Evaluating Prediction Reliability
Assess whether the model's confidence scores accurately reflect the likelihood of correct predictions.
"""
            )
            with gr.Row():
                with gr.Column(scale=1):
                    cal_image_input = gr.Image(
                        label="π Upload Sample Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350,
                    )
                    with gr.Accordion("βοΈ Calibration Settings", open=True):
                        n_bins = gr.Slider(
                            minimum=5,
                            maximum=20,
                            value=10,
                            step=1,
                            label="π Number of Bins",
                            info="Granularity of calibration analysis (5-20)",
                        )
                        gr.Markdown(
                            """
**Calibration Metrics:**
- **Perfect calibration**: Confidence matches accuracy
- **Overconfident**: High confidence, low accuracy
- **Underconfident**: Low confidence, high accuracy
"""
                        )
                    cal_analyze_btn = gr.Button(
                        "π Analyze Calibration", variant="primary", size="lg"
                    )
                    cal_status_output = gr.Textbox(
                        label="π Analysis Status",
                        interactive=False,
                        placeholder="Upload an image and click to analyze calibration...",
                        lines=5,
                        max_lines=8,
                    )
                with gr.Column(scale=2):
                    cal_explanation_display = gr.Plot(label="π Calibration Analysis Results")
                    gr.Markdown(
                        """
**Interpreting Calibration:**
- A well-calibrated model's confidence should match its accuracy
- If the model predicts 80% confidence, it should be correct 80% of the time
- Large deviations indicate calibration issues requiring attention
"""
                    )
            # Two outputs: the handler must return exactly (figure, status).
            cal_analyze_btn.click(
                fn=analyze_calibration,
                inputs=[cal_image_input, model_choice, n_bins],
                outputs=[cal_explanation_display, cal_status_output],
            )
        # Tab 4: Bias Detection -> analyze_bias_detection (2 outputs)
        with gr.TabItem("βοΈ Bias Detection"):
            gr.Markdown(
                """
### Identifying Performance Disparities
Detect potential biases by comparing model performance across different data subgroups.
"""
            )
            with gr.Row():
                with gr.Column(scale=1):
                    bias_image_input = gr.Image(
                        label="π Upload Sample Image",
                        type="pil",
                        sources=["upload", "clipboard"],
                        height=350,
                    )
                    gr.Markdown(
                        """
**Generated Subgroups:**
- Original image (baseline)
- Increased brightness
- Decreased brightness
- Enhanced contrast
"""
                    )
                    bias_analyze_btn = gr.Button("βοΈ Detect Bias", variant="primary", size="lg")
                    bias_status_output = gr.Textbox(
                        label="π Analysis Status",
                        interactive=False,
                        placeholder="Upload an image and click to detect potential biases...",
                        lines=6,
                        max_lines=10,
                    )
                with gr.Column(scale=2):
                    bias_explanation_display = gr.Plot(label="βοΈ Bias Detection Results")
                    gr.Markdown(
                        """
**Understanding Bias Metrics:**
- Compare confidence scores across subgroups
- Large disparities may indicate systematic biases
- Consider demographic, environmental, and quality variations
- Use findings to improve data collection and model training
"""
                    )
            # Two outputs: the handler must return exactly (figure, status).
            bias_analyze_btn.click(
                fn=analyze_bias_detection,
                inputs=[bias_image_input, model_choice],
                outputs=[bias_explanation_display, bias_status_output],
            )
    # Footer: static credits.
    gr.HTML(
        """
<div style="
margin-top: 3rem;
padding: 2rem;
background: rgba(30, 41, 59, 0.3);
border-top: 1px solid rgba(99, 102, 241, 0.2);
border-radius: 16px;
text-align: center;
">
<p style="
color: #64748b;
font-size: 0.95rem;
margin: 0;
">
Built with β€οΈ using <strong style="color: #a5b4fc;">Gradio</strong>,
<strong style="color: #c4b5fd;">Transformers</strong>, and
<strong style="color: #f9a8d4;">Captum</strong>
</p>
<p style="
color: #475569;
font-size: 0.85rem;
margin-top: 0.5rem;
">
Β© 2024 ViT Auditing Toolkit β’ For research and educational purposes
</p>
</div>
"""
    )
# Launch the application
if __name__ == "__main__":
    # Listen on all interfaces. Only the port is configurable: it is taken
    # from $PORT when set (e.g. by a hosting platform such as Hugging Face
    # Spaces) and otherwise defaults to 7860. `os` is already imported at the
    # top of the module, so no local re-import is needed.
    host = "0.0.0.0"
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name=host, server_port=port, share=False, show_error=True)
|