Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| import argparse | |
| from inference import GenerativeInferenceModel, get_inference_configs | |
# Command-line interface: the listening port is the only configurable option.
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
parser.add_argument(
    '--port',
    type=int,
    default=7860,
    help='Port to run the server on',
)
args = parser.parse_args()

# Make sure the working directories are present; existing ones are left as-is.
for directory in ("models", "stimuli"):
    os.makedirs(directory, exist_ok=True)

# One model instance is created at import time and shared by every request.
model = GenerativeInferenceModel()
def run_inference(image, model_type, illusion_type, eps_value, num_iterations):
    """Run generative inference on an image and collect the intermediate steps.

    Args:
        image: Input image coming from the UI image component.
        model_type: Model identifier, forwarded unchanged to ``model.inference``.
        illusion_type: Illusion label selected in the UI. Not used by the
            computation; kept so the handler signature matches the UI inputs.
        eps_value: Perturbation size; converted to ``float``.
        num_iterations: Number of iterations; converted to ``int``.

    Returns:
        A ``(final_output, frames)`` tuple: the first value returned by
        ``model.inference`` and a list of PIL images, one per inference step.

    Raises:
        gr.Error: If no input image was supplied.
    """
    # Clicking "Run" with an empty image component passes None; report it to
    # the user instead of failing somewhere inside the model.
    if image is None:
        raise gr.Error("Please upload an image or select an example first.")

    config = get_inference_configs(eps=float(eps_value), n_itr=int(num_iterations))
    output_images, all_steps = model.inference(image, model_type, config)

    frames = []
    for step_image in all_steps:
        # assumes each step is a (C, H, W) float tensor scaled to [0, 1] — TODO confirm
        # detach(): .numpy() raises on tensors that still require grad.
        # clamp(): values outside [0, 1] would wrap around in the uint8 cast.
        step_array = step_image.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
        frames.append(Image.fromarray((step_array * 255).astype(np.uint8)))

    # Return the frame list directly: component-level ``update`` helpers such
    # as ``gr.Gallery.update`` no longer exist in Gradio 4.x, while a plain
    # value works for a Gallery output on both 3.x and 4.x.
    return output_images, frames
# Define the interface: controls in the left column, results in the wider
# right column, then examples, the click handler and an explanatory note.
with gr.Blocks(title="Generative Inference Demo") as demo:
    gr.Markdown("# Generative Inference Demo")
    gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")

    with gr.Row():
        with gr.Column(scale=1):
            # Inputs
            image_input = gr.Image(label="Upload Image or Select an Illusion", type="pil")

            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=["robust_resnet50", "standard_resnet50"],
                    value="robust_resnet50",
                    label="Model"
                )
                illusion_type = gr.Dropdown(
                    choices=["Kanizsa", "Face-Vase", "Neon-Color", "Figure-Ground"],
                    value="Kanizsa",
                    label="Illusion Type"
                )

            with gr.Row():
                eps_slider = gr.Slider(minimum=0.01, maximum=3.0, value=0.5, step=0.01, label="Epsilon (Perturbation Size)")
                iterations_slider = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Number of Iterations")

            run_button = gr.Button("Run Inference")

        with gr.Column(scale=2):
            # Outputs
            output_image = gr.Image(label="Final Inferred Image")
            output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)

    # Set up example images. Each row is
    # [image path, model, illusion type, epsilon, iterations].
    example_rows = [
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "Kanizsa", 0.5, 50],
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "Face-Vase", 0.5, 50],
        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "Figure-Ground", 0.7, 100],
        [os.path.join("stimuli", "NeonColorSaeedi.jpg"), "robust_resnet50", "Neon-Color", 0.3, 80],
    ]
    # The "stimuli" directory may have just been created empty, so keep only
    # the rows whose image file is actually present, and skip the examples
    # widget entirely when none are available.
    examples = [row for row in example_rows if os.path.exists(row[0])]
    if examples:
        gr.Examples(examples=examples, inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider])

    # Set up event handler
    run_button.click(
        fn=run_inference,
        inputs=[image_input, model_choice, illusion_type, eps_slider, iterations_slider],
        outputs=[output_image, output_frames]
    )

    # Include a description of the technique
    gr.Markdown("""
## About Generative Inference
Generative inference is a technique that reveals how neural networks perceive visual stimuli by optimizing the input
to increase the network's confidence in its predictions. This process can reveal emergent perception of contours,
figure-ground separation, and other visual phenomena similar to human perception.
This demo allows you to:
1. Upload your own images or select from example illusions
2. Choose between robust or standard models
3. Adjust parameters like perturbation size (epsilon) and number of iterations
4. Visualize how the perception emerges over time
""")
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    print(f"Starting server on port {args.port}")
    launch_options = {
        "server_name": "0.0.0.0",   # bind to every network interface
        "server_port": args.port,   # port chosen on the command line
        "share": False,
        "debug": True,
    }
    demo.launch(**launch_options)