| """ | |
| Explainability: Grad-CAM heatmaps and t-SNE feature visualization. | |
| Usage: | |
| python -m src.explainability --gradcam | |
| python -m src.explainability --tsne | |
| python -m src.explainability --all | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from sklearn.manifold import TSNE | |
| from tqdm import tqdm | |
def reshape_transform(tensor, height=14, width=14):
    """Reshape a ViT token sequence into a spatial feature map for Grad-CAM.

    ViT outputs (B, 1 + H*W, C): a CLS token followed by patch tokens.
    The CLS token is dropped and the patch tokens are folded into a
    (B, C, H, W) grid. If the patch-token count does not match
    height * width, a square grid size is inferred from the sequence
    length, generalizing the transform to other input resolutions /
    patch sizes while keeping the 224px/16px default behavior.

    Args:
        tensor: ViT hidden states of shape (B, num_tokens, channels),
            where num_tokens includes the leading CLS token.
        height: Expected grid height (default 14, i.e. 224px / 16px patches).
        width: Expected grid width (default 14).

    Returns:
        Tensor of shape (B, channels, height, width).

    Raises:
        ValueError: if the patch-token count is neither height * width
            nor a perfect square.
    """
    # Remove CLS token (first token); the rest are patch tokens.
    patches = tensor[:, 1:, :]
    num_tokens = patches.size(1)
    if num_tokens != height * width:
        # Fall back to inferring a square grid (e.g. 64 tokens -> 8x8).
        side = int(round(num_tokens ** 0.5))
        if side * side != num_tokens:
            raise ValueError(
                f"Cannot reshape {num_tokens} patch tokens into a spatial grid"
            )
        height = width = side
    # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W)
    result = patches.reshape(patches.size(0), height, width, patches.size(2))
    return result.permute(0, 3, 1, 2)
| from src.dataset import DateFruitDataset, get_val_transforms | |
| from src.models.vit_pretrained import PretrainedViTClassifier | |
| from src.utils import load_config, get_device, seed_everything | |
# ImageNet channel statistics; used here to undo the normalization applied
# by the validation transforms before plotting images.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# The nine date varieties. NOTE(review): assumed to match the integer label
# order produced by DateFruitDataset (index i -> CLASS_NAMES[i]) — confirm
# against src.dataset.
CLASS_NAMES = [
    "Ajwa", "Galaxy", "Medjool", "Meneifi", "Nabtat Ali",
    "Rutab", "Shaishe", "Sokari", "Sugaey",
]
def load_vit_model(checkpoint_path: str, device: torch.device) -> PretrainedViTClassifier:
    """Load a trained ViT classifier from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint containing "model_state_dict".
        device: Device to place the model on.

    Returns:
        The model on `device`, in eval mode.
    """
    # Derive the head size from CLASS_NAMES instead of hard-coding 9, so the
    # two stay in sync if the class list ever changes.
    model = PretrainedViTClassifier(num_classes=len(CLASS_NAMES))
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model
def denormalize(tensor: torch.Tensor) -> np.ndarray:
    """Invert ImageNet normalization, returning an (H, W, 3) float array in [0, 1]."""
    stats_shape = (3, 1, 1)
    channel_mean = torch.tensor(IMAGENET_MEAN).view(*stats_shape)
    channel_std = torch.tensor(IMAGENET_STD).view(*stats_shape)
    # x * std + mean inverts the (x - mean) / std applied by the transforms.
    restored = tensor.cpu().mul(channel_std).add(channel_mean)
    return restored.clamp(0, 1).permute(1, 2, 0).numpy()
def generate_gradcam(
    model: PretrainedViTClassifier,
    dataset: DateFruitDataset,
    device: torch.device,
    samples_per_class: int = 2,
    save_path: str = "results/gradcam_grid.png",
) -> None:
    """Generate a grid of Grad-CAM heatmaps for each variety.

    Each row is one variety; each sampled image contributes two columns:
    the original image and its Grad-CAM overlay.

    Args:
        model: Trained ViT classifier.
        dataset: Dataset yielding (image_tensor, label, variety_name).
        device: Device for the forward/backward passes.
        samples_per_class: Number of images visualized per variety.
        save_path: Output PNG path (parent directories are created).
    """
    # ViT target layer: last encoder block's pre-attention LayerNorm;
    # reshape_transform folds the token sequence back into a spatial grid.
    target_layer = model.backbone.vit.encoder.layer[-1].layernorm_before
    cam = GradCAM(model=model, target_layers=[target_layer], reshape_transform=reshape_transform)
    num_classes = len(CLASS_NAMES)
    fig, axes = plt.subplots(
        num_classes, samples_per_class * 2,
        figsize=(4 * samples_per_class * 2, 3 * num_classes),
    )
    # Collect up to samples_per_class indices per variety. Skip varieties not
    # in CLASS_NAMES instead of raising KeyError, and stop scanning once every
    # bucket is full.
    class_indices = {v: [] for v in CLASS_NAMES}
    for idx in range(len(dataset)):
        _, _, variety = dataset[idx]
        bucket = class_indices.get(variety)
        if bucket is not None and len(bucket) < samples_per_class:
            bucket.append(idx)
            if all(len(b) >= samples_per_class for b in class_indices.values()):
                break
    for row, variety in enumerate(CLASS_NAMES):
        indices = class_indices[variety]
        for col, idx in enumerate(indices):
            image_tensor, _, _ = dataset[idx]
            input_tensor = image_tensor.unsqueeze(0).to(device)
            # targets=None lets GradCAM use the model's top-predicted class.
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)
            grayscale_cam = grayscale_cam[0, :]
            rgb_img = denormalize(image_tensor)
            # Original image
            ax_orig = axes[row, col * 2]
            ax_orig.imshow(rgb_img)
            ax_orig.set_title(variety, fontsize=10, fontweight="bold")
            ax_orig.axis("off")
            # Grad-CAM overlay
            cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
            ax_cam = axes[row, col * 2 + 1]
            ax_cam.imshow(cam_image)
            ax_cam.set_title("Grad-CAM", fontsize=10)
            ax_cam.axis("off")
        # Blank out unused axes when a variety has fewer samples than requested.
        for col in range(len(indices), samples_per_class):
            axes[row, col * 2].axis("off")
            axes[row, col * 2 + 1].axis("off")
    plt.suptitle("Grad-CAM: What the ViT Model Focuses On", fontsize=16, fontweight="bold", y=1.01)
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {save_path}")
def generate_tsne(
    model: PretrainedViTClassifier,
    dataset: DateFruitDataset,
    device: torch.device,
    save_path: str = "results/tsne.png",
) -> None:
    """Project ViT CLS-token features to 2-D with t-SNE and plot the clusters.

    Args:
        model: Trained ViT classifier (only the backbone is used).
        dataset: Dataset yielding (image_tensor, label, variety_name).
        device: Device for feature extraction.
        save_path: Output PNG path (parent directories are created).
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, labels, _ in tqdm(loader, desc="Extracting features"):
            outputs = model.backbone.vit(pixel_values=images.to(device))
            # The CLS token (position 0) summarizes the whole image.
            cls_embedding = outputs.last_hidden_state[:, 0, :]
            feature_batches.append(cls_embedding.cpu().numpy())
            label_batches.append(labels.numpy())
    features = np.concatenate(feature_batches)
    labels = np.concatenate(label_batches)
    print(f"Running t-SNE on {features.shape[0]} samples, {features.shape[1]} dimensions...")
    # Perplexity must be strictly less than the sample count, hence the min().
    reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, len(features) - 1))
    embeddings = reducer.fit_transform(features)
    _, ax = plt.subplots(figsize=(12, 10))
    palette = plt.cm.tab10(np.linspace(0, 1, len(CLASS_NAMES)))
    for class_id, variety in enumerate(CLASS_NAMES):
        selected = labels == class_id
        ax.scatter(
            embeddings[selected, 0],
            embeddings[selected, 1],
            c=[palette[class_id]],
            label=variety,
            alpha=0.7,
            s=50,
            edgecolors="white",
            linewidth=0.5,
        )
    ax.legend(fontsize=10, loc="best")
    ax.set_title("t-SNE: How ViT Clusters Saudi Date Varieties", fontsize=14, fontweight="bold")
    ax.set_xlabel("t-SNE Dimension 1")
    ax.set_ylabel("t-SNE Dimension 2")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved: {save_path}")
def main():
    """CLI entry point: generate Grad-CAM and/or t-SNE visualizations."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--tsne", action="store_true")
    parser.add_argument("--all", action="store_true")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/vit_best_model.pth")
    args = parser.parse_args()

    # --all is shorthand for enabling both visualizations.
    if args.all:
        args.gradcam = args.tsne = True
    if not (args.gradcam or args.tsne):
        print("Specify --gradcam, --tsne, or --all")
        return

    config = load_config()
    seed_everything(42)
    device = get_device()
    print(f"Device: {device}")

    print(f"Loading model from {args.checkpoint}...")
    model = load_vit_model(args.checkpoint, device)
    dataset = DateFruitDataset("data/test.csv", transform=get_val_transforms(config))
    print(f"Test set: {len(dataset)} images")

    if args.gradcam:
        print("\nGenerating Grad-CAM...")
        generate_gradcam(model, dataset, device)
    if args.tsne:
        print("\nGenerating t-SNE...")
        generate_tsne(model, dataset, device)


if __name__ == "__main__":
    main()