"""Gradio demo for generative inference on visual illusions.

Run with ``python <this file> --port 7860``. The UI lets a user upload an
image (or pick one of the example stimuli under ``stimuli/``), choose a model
and inference parameters, and view the final inferred image together with the
intermediate inference steps.
"""

import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import argparse
from inference import GenerativeInferenceModel, get_inference_configs

# Parse command line arguments.
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
parser.add_argument('--port', type=int, default=7860, help='Port to run the server on')
# parse_known_args (rather than parse_args) so that launchers or notebooks that
# inject their own argv entries do not abort this module with SystemExit.
args, _unknown_args = parser.parse_known_args()

# Create model directories if they don't exist
os.makedirs("models", exist_ok=True)
os.makedirs("stimuli", exist_ok=True)

# Initialize model
model = GenerativeInferenceModel()


def _to_pil_frame(step_image):
    """Convert one intermediate inference step (a torch tensor) to a PIL image.

    The original ``permute(1, 2, 0)`` implies a channels-first (C, H, W) layout
    and the ``* 255`` implies values in [0, 1]; both are assumptions about
    ``GenerativeInferenceModel.inference`` -- TODO confirm.
    """
    # detach(): .numpy() raises on tensors that still require grad.
    # clamp(): floats outside [0, 1] are not representable in uint8 after
    # scaling and would cast unpredictably.
    array = step_image.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    array = (array * 255).astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
    if array.ndim == 3 and array.shape[2] == 1:
        array = array[:, :, 0]
    return Image.fromarray(array)


def run_inference(image, model_type, illusion_type, eps_value, num_iterations):
    """Run generative inference on ``image`` and collect the per-step frames.

    Args:
        image: PIL image from the input component, or ``None`` if the user has
            not uploaded/selected anything.
        model_type: model name selected in the "Model" dropdown.
        illusion_type: label from the "Illusion Type" dropdown. It is not passed
            to the model; it is kept so the UI wiring stays unchanged.
        eps_value: perturbation size; converted to ``float``.
        num_iterations: number of inference iterations; converted to ``int``.

    Returns:
        A pair ``(output_images, frames)``: the first value returned by
        ``model.inference`` and a list of PIL images for the steps gallery.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a model traceback.
        raise gr.Error("Please upload an image or select an example first.")

    # Load inference configuration
    config = get_inference_configs(eps=float(eps_value), n_itr=int(num_iterations))

    # Run generative inference
    output_images, all_steps = model.inference(image, model_type, config)

    # Create animation frames
    frames = [_to_pil_frame(step_image) for step_image in all_steps]

    # Return the plain list: gr.Gallery.update() was removed in Gradio 4, and a
    # list of images is a valid Gallery value in both Gradio 3.x and 4.x.
    return output_images, frames


# Define the interface
with gr.Blocks(title="Generative Inference Demo") as demo:
    gr.Markdown("# Generative Inference Demo")
    gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")

    with gr.Row():
        with gr.Column(scale=1):
            # Inputs
            image_input = gr.Image(label="Upload Image or Select an Illusion", type="pil")

            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=["robust_resnet50", "standard_resnet50"],
                    value="robust_resnet50",
                    label="Model"
                )
                illusion_type = gr.Dropdown(
                    choices=["Kanizsa", "Face-Vase", "Neon-Color", "Figure-Ground"],
                    value="Kanizsa",
                    label="Illusion Type"
                )

            with gr.Row():
                eps_slider = gr.Slider(minimum=0.01, maximum=3.0, value=0.5, step=0.01, label="Epsilon (Perturbation Size)")
                iterations_slider = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Number of Iterations")

            run_button = gr.Button("Run Inference")

        with gr.Column(scale=2):
            # Outputs
            output_image = gr.Image(label="Final Inferred Image")
            output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)

    # Set up example images
    examples = [
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "Kanizsa", 0.5, 50],
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "Face-Vase", 0.5, 50],
        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "Figure-Ground", 0.7, 100],
        [os.path.join("stimuli", "NeonColorSaeedi.jpg"), "robust_resnet50", "Neon-Color", 0.3, 80]
    ]
    # Only "stimuli/" itself is guaranteed to exist (created empty above), so
    # skip examples whose image file is missing rather than breaking the UI.
    examples = [row for row in examples if os.path.isfile(row[0])]
    if examples:
        gr.Examples(examples=examples, inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider])

    # Set up event handler
    run_button.click(
        fn=run_inference,
        inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider],
        outputs=[output_image, output_frames]
    )

    # Include a description of the technique
    gr.Markdown("""
    ## About Generative Inference

    Generative inference is a technique that reveals how neural networks perceive visual stimuli by optimizing the input to increase the network's confidence in its predictions. This process can reveal emergent perception of contours, figure-ground separation, and other visual phenomena similar to human perception.

    This demo allows you to:
    1. Upload your own images or select from example illusions
    2. Choose between robust or standard models
    3. Adjust parameters like perturbation size (epsilon) and number of iterations
    4. Visualize how the perception emerges over time
    """)

# Launch the demo with specific settings
if __name__ == "__main__":
    print(f"Starting server on port {args.port}")
    demo.launch(
        server_name="0.0.0.0",  # Listen on all interfaces
        server_port=args.port,  # Use the port from command line arguments
        share=False,
        debug=True
    )