"""Gradient-guided "generative inference" on ImageNet ResNet-50 models.

The module downloads (or falls back to torchvision's) ResNet-50 weights,
then iteratively perturbs an input image so that the model becomes more
confident in the classes it initially found least likely.
"""

import os
import time
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models.resnet import ResNet50_Weights

# Check CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
MODEL_URLS = {
    'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_3.0.pt',
    'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_0.0.pt'
}

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default transform: produces an un-normalized tensor in [0, 1]; normalization
# is applied separately so the optimization can work in pixel space.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

normalize_transform = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)


# Get ImageNet labels
def get_imagenet_labels():
    """Fetch the ImageNet "simple labels" JSON and return the decoded object.

    Raises:
        RuntimeError: if the server answers with a non-200 status code.
    """
    url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
    # A timeout keeps a stalled connection from hanging the caller forever.
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return response.json()
    raise RuntimeError("Failed to fetch ImageNet labels")


# Download model if needed
def download_model(model_type):
    """Ensure the checkpoint for *model_type* exists locally and return its path.

    Args:
        model_type: key into ``MODEL_URLS``.

    Returns:
        ``Path("models/<model_type>.pt")``, or ``None`` when no URL is
        registered for *model_type* (the caller then uses PyTorch's
        pretrained model).

    Raises:
        RuntimeError: if the download answers with a non-200 status code.
    """
    if MODEL_URLS.get(model_type) is None:
        return None  # Use PyTorch's pretrained model

    model_path = Path(f"models/{model_type}.pt")
    if model_path.exists():
        return model_path

    print(f"Downloading {model_type} model...")
    url = MODEL_URLS[model_type]

    # The target directory is not guaranteed to exist on a fresh checkout.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Stream into a sibling ".part" file and rename on success, so an
    # interrupted download never leaves a truncated checkpoint that a later
    # call would mistake for a complete one.
    tmp_path = model_path.with_name(model_path.name + ".part")
    with requests.get(url, stream=True, timeout=30) as response:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download model: {response.status_code}")
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            os.replace(tmp_path, model_path)
        finally:
            # Only still present if the download or rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    print(f"Model downloaded and saved to {model_path}")
    return model_path


class NormalizeByChannelMeanStd(nn.Module):
    """Module that normalizes a batch by per-channel mean and std.

    ``mean`` and ``std`` are registered as buffers, so they follow the
    module across ``.to(device)`` calls.
    """

    def __init__(self, mean, std):
        super().__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.functional.normalize"""
        # here we assume the color channel is at dim=1
        mean = mean[None, :, None, None]
        std = std[None, :, None, None]
        return tensor.sub(mean).div(std)


class InferStep:
    """Update rule and constraint set for the iterative image optimization.

    Args:
        orig_image: reference image the perturbation is measured against.
        eps: maximum absolute per-element deviation from ``orig_image``.
        step_size: length of each normalized gradient step.
    """

    def __init__(self, orig_image, eps, step_size):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x):
        """Clamp *x* to within ``eps`` of the original image and to [0, 1]."""
        diff = x - self.orig_image
        diff = torch.clamp(diff, -self.eps, self.eps)
        return torch.clamp(self.orig_image + diff, 0, 1)

    def step(self, x, grad):
        """Return ``grad`` scaled to per-sample norm ``step_size``.

        The norm is taken over all non-batch dimensions of ``grad``; a small
        constant in the denominator guards against division by zero.
        """
        n_trailing = len(x.shape) - 1
        # reshape (not view) so non-contiguous gradients are accepted too.
        flat = grad.reshape(grad.shape[0], -1)
        grad_norm = torch.norm(flat, dim=1).view(-1, *([1] * n_trailing))
        scaled_grad = grad / (grad_norm + 1e-10)
        return scaled_grad * self.step_size


def get_inference_configs(eps=0.5, n_itr=50):
    """Generate inference configuration with customizable parameters.

    Args:
        eps: maximum perturbation size.
        n_itr: number of optimization iterations.

    Returns:
        A fresh ``dict`` of settings. Within this module,
        ``GenerativeInferenceModel.inference`` reads ``'eps'``,
        ``'step_size'``, ``'n_itr'`` and ``'iterations_to_show'``.
    """
    config = {
        'loss_infer': 'IncreaseConfidence',      # How to guide the optimization
        'loss_function': 'CE',                   # Loss function: Cross Entropy
        'n_itr': n_itr,                          # Number of iterations
        'eps': eps,                              # Maximum perturbation size
        'step_size': 0.02,                       # Step size for each iteration
        'diffusion_noise_ratio': 0.0,            # No diffusion noise
        'initial_inference_noise_ratio': 0.0,    # No initial noise
        'top_layer': 'all',                      # Use all layers of the model
        'inference_normalization': 'on',         # Apply normalization during inference
        'recognition_normalization': 'on',       # Apply normalization during recognition
        'iterations_to_show': [1, 5, 10, 20, 30, 40, 50, n_itr]  # Specific iterations to visualize
    }
    return config


class GenerativeInferenceModel:
    """Caches ResNet-50 models and runs the iterative inference procedure.

    Constructing an instance fetches the ImageNet labels over the network
    (see ``get_imagenet_labels``).
    """

    def __init__(self):
        self.models = {}
        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
        self.labels = get_imagenet_labels()

    def load_model(self, model_type):
        """Return the model for *model_type*, loading and caching it on first use.

        If ``download_model`` yields a checkpoint path, its weights are loaded
        into a ResNet-50; otherwise torchvision's ``IMAGENET1K_V1`` weights
        are used. The model is moved to ``device`` and put in eval mode.
        """
        if model_type in self.models:
            return self.models[model_type]

        model_path = download_model(model_type)

        if model_path:
            # Create standard ResNet50 model and fill it from the checkpoint
            model = models.resnet50()
            print(f"Loading {model_type} model from {model_path}...")
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints obtained from a trusted source.
            checkpoint = torch.load(model_path, map_location=device)

            # Handle different checkpoint formats
            if 'model' in checkpoint:
                # Format from madrylab robust models
                state_dict = checkpoint['model']
            elif 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                # Direct state dict
                state_dict = checkpoint

            # Remove the 'module.' prefix from state dict keys, if present
            prefix = 'module.'
            new_state_dict = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
            model.load_state_dict(new_state_dict)
        else:
            # Fallback to PyTorch's pretrained model
            model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        model = model.to(device)
        model.eval()  # Set to evaluation mode

        # Store the model for future use
        self.models[model_type] = model
        return model

    def inference(self, image, model_type, config):
        """Iteratively perturb *image* toward the model's least confident classes.

        Args:
            image: an image accepted by the module-level ``transform``, or a
                string path to an image file (opened and converted to RGB).
            model_type: key passed to ``load_model``.
            config: mapping providing ``'eps'``, ``'step_size'``, ``'n_itr'``
                and ``'iterations_to_show'`` (see ``get_inference_configs``).

        Returns:
            ``(final_image, steps)``: the final image (first batch element, on
            CPU) and a list holding the initial image followed by the image
            after each iteration listed in ``iterations_to_show`` and after
            the last iteration.

        Raises:
            ValueError: if *image* is a string that is not an existing path.
        """
        # Load model if not already loaded
        model = self.load_model(model_type)

        # Check if image is a file path
        if isinstance(image, str):
            if os.path.exists(image):
                image = Image.open(image).convert('RGB')
            else:
                raise ValueError(f"Image path does not exist: {image}")

        # Prepare image tensor (un-normalized, batch of one)
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Get original predictions and the least confident classes
        with torch.no_grad():
            output_original = model(normalize_transform(image_tensor))
            probs_orig = F.softmax(output_original, dim=1)
            _, least_confident_classes = torch.topk(probs_orig, k=100, largest=False)

        # topk returns shape (batch, k); index the single batch row before
        # slicing so we really get 10 class indices.
        target_classes = least_confident_classes[0, :10]

        # Initialize inference step
        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        n_itr = config['n_itr']
        iterations_to_show = set(config['iterations_to_show'])

        # x must be a leaf tensor that requires grad for the gradient w.r.t.
        # the image to exist.
        x = image_tensor.clone().requires_grad_(True)
        all_steps = [image_tensor[0].detach().cpu()]

        # Main inference loop
        for i in range(n_itr):
            # Forward pass on the normalized input
            output = model(normalize_transform(x))

            # Negative summed cross entropy over the target classes: raising
            # this loss raises the model's confidence in those classes.
            # Expanding the single logit row lets one call cover all targets.
            logits = output.expand(target_classes.shape[0], -1)
            loss = -F.cross_entropy(logits, target_classes, reduction='sum')

            # Gradient w.r.t. the image only; avoids accumulating gradients
            # in the model parameters.
            (grad,) = torch.autograd.grad(loss, x)

            # Update image
            with torch.no_grad():
                x = infer_step.project(x + infer_step.step(x, grad))
            # The tensor built under no_grad is a fresh leaf; re-enable grad
            # for the next iteration.
            x.requires_grad_(True)

            # Store step if in iterations_to_show
            if i + 1 in iterations_to_show or i + 1 == n_itr:
                all_steps.append(x[0].detach().cpu())

        # Return final image and all stored steps
        return x[0].detach().cpu(), all_steps


# Utility function to show inference steps
def show_inference_steps(steps, figsize=(15, 10)):
    """Plot the stored inference steps side by side and return the figure.

    Args:
        steps: sequence of image tensors; each is permuted with ``(1, 2, 0)``
            before display.
        figsize: figure size passed to ``plt.subplots``.
    """
    import matplotlib.pyplot as plt

    n_steps = len(steps)
    # squeeze=False always yields a 2-D axes array, so a single step works too.
    fig, axes = plt.subplots(1, n_steps, figsize=figsize, squeeze=False)

    for i, step_img in enumerate(steps):
        img = step_img.permute(1, 2, 0).numpy()
        ax = axes[0][i]
        ax.imshow(img)
        ax.set_title(f"Step {i}")
        ax.axis('off')

    plt.tight_layout()
    return fig