GenerativeInferenceDemo / inference.py
ttoosi's picture
Upload 11 files
7449d44 verified
raw
history blame
9.36 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet50_Weights
from PIL import Image
import numpy as np
import os
import requests
import time
from pathlib import Path
# Pick the compute device once at import time; models and tensors are moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Constants
# Checkpoint URL per model type; download_model() caches each file under models/.
MODEL_URLS = {
    'robust_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_3.0.pt',
    'standard_resnet50': 'https://huggingface.co/madrylab/robust-imagenet-models/resolve/main/resnet50_l2_eps_0.0.pt',
}
# Per-channel ImageNet statistics used for input normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Default preprocessing: resize shorter side to 224, center-crop 224x224, to tensor.
# Normalization is deliberately kept out of this pipeline so it can be applied
# separately at model-input time.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)
normalize_transform = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
def get_imagenet_labels(
    url="https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json",
    timeout=30,
):
    """Fetch the ImageNet class labels as parsed JSON.

    Args:
        url: Location of the labels JSON file.
        timeout: Seconds to wait for connect/read before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled server.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: if the server answers with a non-200 status code.
        requests.exceptions.RequestException: on network errors or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError("Failed to fetch ImageNet labels")
    return response.json()
def download_model(model_type, models_dir="models", timeout=60):
    """Return a local path to the checkpoint for ``model_type``, downloading it once.

    Args:
        model_type: Key into ``MODEL_URLS``.
        models_dir: Directory used as the on-disk cache (created if missing).
        timeout: Seconds for ``requests`` connect/read timeouts.

    Returns:
        ``Path`` to ``<models_dir>/<model_type>.pt``, or ``None`` when
        ``model_type`` has no URL in ``MODEL_URLS`` (the caller then falls back
        to PyTorch's pretrained model).

    Raises:
        RuntimeError: if the download responds with a non-200 status code.
    """
    url = MODEL_URLS.get(model_type)
    if url is None:
        return None  # Use PyTorch's pretrained model

    model_path = Path(models_dir) / f"{model_type}.pt"
    if not model_path.exists():
        print(f"Downloading {model_type} model...")
        model_path.parent.mkdir(parents=True, exist_ok=True)
        # Stream into a temporary file and rename only on success, so an
        # interrupted download never leaves a truncated checkpoint at model_path
        # (a stale .part file is simply overwritten on the next attempt).
        tmp_path = model_path.with_name(model_path.name + ".part")
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Failed to download model: {response.status_code}")
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        tmp_path.replace(model_path)
        print(f"Model downloaded and saved to {model_path}")
    return model_path
class NormalizeByChannelMeanStd(nn.Module):
    """Per-channel ``(x - mean) / std`` normalization packaged as an ``nn.Module``.

    ``mean`` and ``std`` are registered as buffers, so they move with the
    module on ``.to(device)`` and are included in its ``state_dict``.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Plain sequences are converted to tensors; tensors are stored as given.
        mean_buf = mean if isinstance(mean, torch.Tensor) else torch.tensor(mean)
        std_buf = std if isinstance(std, torch.Tensor) else torch.tensor(std)
        self.register_buffer("mean", mean_buf)
        self.register_buffer("std", std_buf)

    def forward(self, tensor):
        """Normalize ``tensor`` with the stored ``mean``/``std`` buffers."""
        return self.normalize_fn(tensor, self.mean, self.std)

    def normalize_fn(self, tensor, mean, std):
        """Differentiable version of torchvision.functional.normalize.

        The 1-D ``mean``/``std`` are indexed out to shape ``(1, C, 1, 1)``, so
        the colour channel of ``tensor`` is assumed to be dim 1 (e.g. NCHW).
        """
        channel_mean = mean[None, :, None, None]
        channel_std = std[None, :, None, None]
        return (tensor - channel_mean) / channel_std
class InferStep:
    """Gradient-normalized update plus projection around a reference image.

    ``step`` rescales a gradient so each sample's L2 norm is (approximately)
    ``step_size``; ``project`` clips an image element-wise to within ``eps``
    of ``orig_image`` and then into the ``[0, 1]`` range.

    NOTE(review): the step is L2-normalized while the projection is an
    element-wise (L-inf) box; confirm this combination is intentional.
    """

    def __init__(self, orig_image, eps, step_size):
        self.orig_image = orig_image
        self.eps = eps
        self.step_size = step_size

    def project(self, x):
        """Clip ``x`` to ``orig_image +/- eps`` element-wise, then to ``[0, 1]``."""
        delta = (x - self.orig_image).clamp(-self.eps, self.eps)
        return (self.orig_image + delta).clamp(0, 1)

    def step(self, x, grad):
        """Return ``grad`` scaled to per-sample L2 norm ``step_size``.

        Dim 0 is treated as the batch dimension. ``x`` only supplies the rank
        used to broadcast the per-sample norms back over ``grad``; the small
        constant guards against division by zero for an all-zero gradient.
        """
        n_samples = grad.shape[0]
        broadcast_shape = [-1] + [1] * (x.dim() - 1)
        norms = grad.view(n_samples, -1).norm(dim=1).view(*broadcast_shape)
        unit_grad = grad / (norms + 1e-10)
        return unit_grad * self.step_size
def get_inference_configs(eps=0.5, n_itr=50, step_size=0.02):
    """Generate inference configuration with customizable parameters.

    Args:
        eps: Maximum perturbation size.
        n_itr: Number of iterations.
        step_size: Step size for each iteration (default matches the
            previously hard-coded value).

    Returns:
        A config dict. ``iterations_to_show`` is sorted, free of duplicates
        and contains only 1-based iteration numbers in ``1..n_itr``; it always
        includes ``n_itr`` itself when ``n_itr`` is positive.
    """
    # Fixed snapshot schedule plus the final iteration. Deduplicate (the
    # default n_itr=50 used to yield 50 twice) and drop iterations the loop
    # can never reach.
    schedule = (1, 5, 10, 20, 30, 40, 50, n_itr)
    iterations_to_show = sorted({it for it in schedule if 0 < it <= n_itr})
    config = {
        'loss_infer': 'IncreaseConfidence',  # How to guide the optimization
        'loss_function': 'CE',  # Loss function: Cross Entropy
        'n_itr': n_itr,  # Number of iterations
        'eps': eps,  # Maximum perturbation size
        'step_size': step_size,  # Step size for each iteration
        'diffusion_noise_ratio': 0.0,  # No diffusion noise
        'initial_inference_noise_ratio': 0.0,  # No initial noise
        'top_layer': 'all',  # Use all layers of the model
        'inference_normalization': 'on',  # Apply normalization during inference
        'recognition_normalization': 'on',  # Apply normalization during recognition
        'iterations_to_show': iterations_to_show,  # Specific iterations to visualize
    }
    return config
class GenerativeInferenceModel:
    """Lazily load ResNet-50 variants and run gradient-based inference on images."""

    def __init__(self):
        # model_type -> ResNet-50 already moved to ``device`` and set to eval mode.
        self.models = {}
        # Differentiable ImageNet normalization applied to every model input.
        self.normalizer = NormalizeByChannelMeanStd(IMAGENET_MEAN, IMAGENET_STD).to(device)
        # Class labels from get_imagenet_labels(); this performs an HTTP request.
        self.labels = get_imagenet_labels()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return a state dict loadable into a plain torchvision ``resnet50``.

        Accepts a raw state dict or one wrapped under a ``'model'`` or
        ``'state_dict'`` key, then normalizes keys:

        * a ``DataParallel`` ``'module.'`` prefix is removed;
        * ``'attacker.*'`` / ``'normalizer.*'`` entries are dropped and a
          ``'model.'`` prefix is removed. Checkpoints saved by the madrylab
          ``robustness`` library nest the network this way (alongside a
          duplicate attacker copy and a built-in normalizer). A plain ResNet
          state dict has none of these keys and passes through unchanged.
        """
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint  # Direct state dict

        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith('module.'):
                key = key[len('module.'):]
            if key.startswith(('attacker.', 'normalizer.')):
                # Duplicate weights / input normalization: normalization is
                # applied separately via self.normalizer.
                continue
            if key.startswith('model.'):
                key = key[len('model.'):]
            cleaned[key] = value
        return cleaned

    def load_model(self, model_type):
        """Return the eval-mode ResNet-50 for ``model_type``, loading it on first use.

        Types that ``download_model`` resolves to a checkpoint file are loaded
        from it; otherwise torchvision's ``IMAGENET1K_V1`` weights are used.
        Loaded models are cached in ``self.models``.
        """
        if model_type in self.models:
            return self.models[model_type]

        model_path = download_model(model_type)
        if model_path:
            print(f"Loading {model_type} model from {model_path}...")
            model = models.resnet50()
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted sources (the fixed MODEL_URLS here).
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(self._extract_state_dict(checkpoint))
        else:
            # Fallback to PyTorch's pretrained model
            model = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        model = model.to(device)
        model.eval()
        self.models[model_type] = model
        return model

    def inference(self, image, model_type, config):
        """Iteratively perturb ``image`` toward the classes the model finds least likely.

        The 10 classes with the lowest original softmax probability are chosen
        as targets; each iteration takes a normalized gradient step that lowers
        the summed cross-entropy against all of them, then projects the result
        back with ``InferStep.project``.

        Args:
            image: Path (str or os.PathLike) to an image file, or a PIL image.
            model_type: Key passed to ``load_model``.
            config: Dict providing ``'eps'``, ``'step_size'``, ``'n_itr'`` and
                ``'iterations_to_show'`` (1-based iteration numbers). Other keys
                are not read here.

        Returns:
            ``(final_image, all_steps)``: the final image tensor (first batch
            element, on CPU) and a list of CPU snapshots that starts with the
            preprocessed input and includes every iteration in
            ``iterations_to_show`` plus the last one.

        Raises:
            ValueError: if ``image`` is a path that does not exist.
        """
        model = self.load_model(model_type)

        # Accept a file path or a PIL image; always feed 3-channel RGB.
        if isinstance(image, (str, os.PathLike)):
            if not os.path.exists(image):
                raise ValueError(f"Image path does not exist: {image}")
            with Image.open(image) as img:
                image = img.convert('RGB')
        elif isinstance(image, Image.Image) and image.mode != 'RGB':
            image = image.convert('RGB')

        # Batch of one, values from ``transform`` (no normalization yet).
        image_tensor = transform(image).unsqueeze(0).to(device)

        # Original predictions -> the 10 least confident classes are the targets.
        with torch.no_grad():
            probs_orig = F.softmax(model(self.normalizer(image_tensor)), dim=1)
        target_classes = torch.topk(probs_orig[0], k=10, largest=False).indices

        infer_step = InferStep(image_tensor, config['eps'], config['step_size'])
        n_itr = config['n_itr']
        iterations_to_show = set(config['iterations_to_show'])

        x = image_tensor.clone()
        all_steps = [image_tensor[0].detach().cpu()]
        for i in range(n_itr):
            # x is a fresh leaf each iteration (the update below runs under
            # no_grad), so gradient tracking must be re-enabled here.
            x.requires_grad_(True)
            output = model(self.normalizer(x))
            # Sum of cross-entropies of the single output against each target
            # class; negated so that ascending this loss raises target confidence.
            logits = output.expand(target_classes.numel(), -1)
            loss = -F.cross_entropy(logits, target_classes, reduction='sum')
            # Gradient w.r.t. the image only: avoids accumulating .grad on the
            # model parameters across iterations and calls.
            (grad,) = torch.autograd.grad(loss, x)

            with torch.no_grad():
                x = infer_step.project(x + infer_step.step(x, grad))

            if i + 1 in iterations_to_show or i + 1 == n_itr:
                all_steps.append(x[0].detach().cpu())

        return x[0].detach().cpu(), all_steps
def show_inference_steps(steps, figsize=(15, 10), titles=None):
    """Plot a row of CHW image tensors (e.g. the snapshots from inference()).

    Args:
        steps: Non-empty sequence of ``(C, H, W)`` tensors; they may live on
            any device and may require grad.
        figsize: Matplotlib figure size.
        titles: Optional per-image titles; defaults to ``"Step 0"``,
            ``"Step 1"``, ... (list positions, not iteration numbers).

    Returns:
        The created matplotlib ``Figure``.

    Raises:
        ValueError: if ``steps`` is empty or ``titles`` has the wrong length.
    """
    import matplotlib.pyplot as plt

    n_steps = len(steps)
    if n_steps == 0:
        raise ValueError("steps must contain at least one image")
    if titles is None:
        titles = [f"Step {i}" for i in range(n_steps)]
    elif len(titles) != n_steps:
        raise ValueError("titles must have one entry per step")

    # squeeze=False keeps a 2-D Axes array even for a single image; otherwise
    # plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, n_steps, figsize=figsize, squeeze=False)
    for ax, step_img, title in zip(axes[0], steps, titles):
        ax.imshow(step_img.detach().cpu().permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    return fig