import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet50_Weights
from PIL import Image
import numpy as np
import os
import requests
import time
from pathlib import Path

# Check CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
MODEL_URLS = {
    'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_3.0.pt',
    'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_0.0.pt'
}

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default transform: resize the short side to 224, then center-crop to 224x224
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)

# Get ImageNet labels
def get_imagenet_labels():
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return response.json()
    raise RuntimeError(f"Failed to fetch ImageNet labels: HTTP {response.status_code}")

# Download model if needed
def download_model(model_type):
    if model_type not in MODEL_URLS or MODEL_URLS[model_type] is None:
        return None  # Use PyTorch's pretrained model
    model_path = Path(f"models/{model_type}.pt")
    model_path.parent.mkdir(parents=True, exist_ok=True)  # ensure models/ exists
    if not model_path.exists():
        print(f"Downloading {model_type} model...")
        url = MODEL_URLS[model_type]
        response = requests.get(url, stream=True, timeout=60)
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
        with open(model_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Model downloaded and saved to {model_path}")
    return model_path

class NormalizeByChannelMeanStd(nn.Module):
    """Normalization as an nn.Module, so it stays differentiable and moves with .to(device)."""

    def __init__(self, mean, std):
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.transforms.functional.normalize."""
        # Assumes the color channel is at dim=1, i.e. (N, C, H, W)
        mean = mean[None, :, None, None]
        std = std[None, :, None, None]
        return tensor.sub(mean).div(std)
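
# Optional sanity check (illustrative, not used by the app): on a random
# batch, the module above should match torchvision's transforms.Normalize.
def _check_normalizer():
    x = torch.rand(2, 3, 8, 8)
    ours = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD)(x)
    ref = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)(x)
    assert torch.allclose(ours, ref, atol=1e-6), "normalization mismatch"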

class InferStep:
    """One update of pixel-space optimization: an L2-normalized gradient step,
    then projection onto the L-inf ball of radius eps around the original image."""

    def __init__(self, orig_image, eps, step_size):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x):
        # Clamp the perturbation to [-eps, eps] and the result to valid pixels [0, 1]
        diff = x - self.orig_image
        diff = torch.clamp(diff, -self.eps, self.eps)
        return torch.clamp(self.orig_image + diff, 0, 1)

    def step(self, x, grad):
        # Normalize the gradient per sample so each step has L2 norm == step_size
        l = len(x.shape) - 1
        grad_norm = torch.norm(grad.view(grad.shape[0], -1), dim=1).view(-1, *([1] * l))
        scaled_grad = grad / (grad_norm + 1e-10)
        return scaled_grad * self.step_size
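
# Optional sanity check (illustrative): a projected update never leaves the
# eps-ball around the original image and stays in the valid pixel range.
def _check_infer_step(eps=0.5, step_size=0.02):
    orig = torch.rand(1, 3, 8, 8)
    stepper = InferStep(orig, eps, step_size)
    x = stepper.project(orig + stepper.step(orig, torch.randn_like(orig)))
    assert (x - orig).abs().max() <= eps + 1e-6
    assert x.min() >= 0.0 and x.max() <= 1.0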

def get_inference_configs(eps=0.5, n_itr=50):
    """Generate an inference configuration with customizable parameters."""
    config = {
        'loss_infer': 'IncreaseConfidence',    # How to guide the optimization
        'loss_function': 'CE',                 # Loss function: cross-entropy
        'n_itr': n_itr,                        # Number of iterations
        'eps': eps,                            # Maximum perturbation size
        'step_size': 0.02,                     # Step size for each iteration
        'diffusion_noise_ratio': 0.0,          # No diffusion noise
        'initial_inference_noise_ratio': 0.0,  # No initial noise
        'top_layer': 'all',                    # Use all layers of the model
        'inference_normalization': 'on',       # Apply normalization during inference
        'recognition_normalization': 'on',     # Apply normalization during recognition
        'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr]  # Iterations to visualize
    }
    return config

class GenerativeInferenceModel:
    def __init__(self):
        self.models = {}
        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
        self.labels = get_imagenet_labels()

    def load_model(self, model_type):
        if model_type in self.models:
            return self.models[model_type]
        model_path = download_model(model_type)
        # Create a standard ResNet50
        model = models.resnet50()
        # Load the model checkpoint
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            # weights_only=False because madrylab checkpoints pickle more than
            # a bare state dict (only do this for trusted sources)
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
            # Handle different checkpoint formats
            if 'model' in checkpoint:
                # Format from madrylab robust models
                state_dict = checkpoint['model']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                # Direct state dict
                state_dict = checkpoint
            # Strip wrapper prefixes: madrylab checkpoints are saved from a
            # DataParallel-wrapped AttackerModel, so the classifier weights sit
            # under 'module.model.' next to attacker/normalizer buffers that a
            # plain ResNet50 does not have
            new_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith('module.'):
                    key = key[len('module.'):]
                if key.startswith(('attacker.', 'normalizer.')):
                    continue
                if key.startswith('model.'):
                    key = key[len('model.'):]
                new_state_dict[key] = value
            model.load_state_dict(new_state_dict)
        else:
            # Fall back to PyTorch's pretrained model
            model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        model = model.to(device)
        model.eval()  # Set to evaluation mode
        # Cache the model for future use
        self.models[model_type] = model
        return model

    def inference(self, image, model_type, config):
        # Load model if not already loaded
        model = self.load_model(model_type)
        # Accept either a file path or a PIL image
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image path does not exist: {image}")
        # Prepare image tensor
        image_tensor = transform(image).unsqueeze(0).to(device)
        # Get original predictions
        with torch.no_grad():
            output_original = model(self.normalizer(image_tensor))
            probs_orig = F.softmax(output_original, dim=1)
            conf_orig, classes_orig = torch.max(probs_orig, 1)
            # Get the least confident classes
            _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)
        print(f"Original prediction: {self.labels[classes_orig.item()]} "
              f"(confidence {conf_orig.item():.3f})")
        # Initialize inference step
        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        # Storage for inference steps
        x = image_tensor.clone().detach()
        all_steps = [image_tensor[0].detach().cpu()]
        # Use the 10 least confident classes as optimization targets
        target_classes = least_confident_classes[0, :10]
        # Main inference loop
        for i in range(config['n_itr']):
            # x must be a leaf tensor each iteration to receive gradients
            x = x.detach().requires_grad_(True)
            # Normalize input for the model (differentiable)
            normalized_x = self.normalizer(x)
            # Forward pass
            output = model(normalized_x)
            # Maximize confidence for the target classes via negative cross-entropy
            loss = 0
            for idx in target_classes:
                target = torch.tensor([idx.item()], device=device)
                loss = loss - F.cross_entropy(output, target)
            # Gradient of the loss with respect to the input image
            grad, = torch.autograd.grad(loss, [x])
            # Update image: gradient ascent on the loss, then project
            with torch.no_grad():
                x = x + infer_step.step(x, grad)
                x = infer_step.project(x)
            # Store step if in iterations_to_show
            if i + 1 in config['iterations_to_show'] or i + 1 == config['n_itr']:
                all_steps.append(x[0].detach().cpu())
        # Return final image and all stored steps
        return x[0].detach().cpu(), all_steps

# Utility function to show inference steps
def show_inference_steps(steps, figsize=(15, 10)):
    import matplotlib.pyplot as plt
    n_steps = len(steps)
    # squeeze=False keeps axes 2-D even when there is a single step
    fig, axes = plt.subplots(1, n_steps, figsize=figsize, squeeze=False)
    for i, step_img in enumerate(steps):
        img = step_img.permute(1, 2, 0).clamp(0, 1).numpy()
        axes[0][i].imshow(img)
        axes[0][i].set_title(f"Step {i}")
        axes[0][i].axis('off')
    plt.tight_layout()
    return fig
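
# Example usage (illustrative): run generative inference on a local image and
# save the visualization. The image path is a placeholder; substitute your own.
if __name__ == "__main__":
    engine = GenerativeInferenceModel()
    config = get_inference_configs(eps=0.5, n_itr=50)
    final_image, steps = engine.inference("path/to/image.jpg", "robust_resnet50", config)
    fig = show_inference_steps(steps)
    fig.savefig("inference_steps.png")
    print(f"Saved visualization of {len(steps)} steps to inference_steps.png")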