# Provenance (residue from the hosting page's file viewer, not part of the program):
# uploaded by ttoosi, commit 7449d44 ("Upload 11 files", verified), file size 4.63 kB.
import gradio as gr
import torch
import numpy as np
from PIL import Image
import os
import argparse
from inference import GenerativeInferenceModel, get_inference_configs
# Command-line interface: the only option is the port the server binds to.
parser = argparse.ArgumentParser(description='Run Generative Inference Demo')
parser.add_argument(
    '--port',
    type=int,
    default=7860,
    help='Port to run the server on',
)
args = parser.parse_args()

# Make sure the working directories exist before anything tries to use them.
for _required_dir in ("models", "stimuli"):
    os.makedirs(_required_dir, exist_ok=True)

# Single shared model instance used by every inference request.
model = GenerativeInferenceModel()
def run_inference(image, model_type, illusion_type, eps_value, num_iterations):
    """Run generative inference on an image and collect the intermediate steps.

    Args:
        image: Input image delivered by the Gradio image component.
        model_type: Model name, forwarded unchanged to ``model.inference``.
        illusion_type: Selected illusion label. Accepted so the signature
            matches the UI inputs; it does not influence the computation.
        eps_value: Perturbation size; converted to ``float``.
        num_iterations: Number of iterations; converted to ``int``.

    Returns:
        A ``(final_output, frames)`` pair. ``final_output`` is whatever
        ``model.inference`` returns as its first value, and ``frames`` is a
        list of PIL images, one per intermediate step.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of failing inside the model.
        raise gr.Error("Please upload an image or select an example before running inference.")

    config = get_inference_configs(eps=float(eps_value), n_itr=int(num_iterations))
    output_images, all_steps = model.inference(image, model_type, config)

    frames = []
    for step_image in all_steps:
        # Detach so tensors that still track gradients can be converted, and
        # clamp to [0, 1] so out-of-range values do not wrap around in uint8.
        # permute(1, 2, 0) moves the leading axis last, as Image.fromarray expects.
        pixels = step_image.detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
        frames.append(Image.fromarray((pixels * 255).astype(np.uint8)))

    # Return the frame list directly: the per-component ``update()`` helpers
    # (e.g. ``gr.Gallery.update``) no longer exist in Gradio 4, while a plain
    # value is accepted as the gallery output in both 3.x and 4.x.
    return output_images, frames
# Build the Gradio UI: controls on the left, results on the right,
# followed by clickable examples and an explanatory note.
with gr.Blocks(title="Generative Inference Demo") as demo:
    gr.Markdown("# Generative Inference Demo")
    gr.Markdown("This demo showcases how neural networks can perceive visual illusions through generative inference.")

    with gr.Row():
        # Left column: everything the user can configure.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Image or Select an Illusion",
                type="pil",
            )

            with gr.Row():
                model_choice = gr.Dropdown(
                    label="Model",
                    choices=["robust_resnet50", "standard_resnet50"],
                    value="robust_resnet50",
                )
                illusion_type = gr.Dropdown(
                    label="Illusion Type",
                    choices=["Kanizsa", "Face-Vase", "Neon-Color", "Figure-Ground"],
                    value="Kanizsa",
                )

            with gr.Row():
                eps_slider = gr.Slider(
                    label="Epsilon (Perturbation Size)",
                    minimum=0.01,
                    maximum=3.0,
                    value=0.5,
                    step=0.01,
                )
                iterations_slider = gr.Slider(
                    label="Number of Iterations",
                    minimum=10,
                    maximum=200,
                    value=50,
                    step=10,
                )

            run_button = gr.Button("Run Inference")

        # Right column: the final image and the per-step gallery.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Final Inferred Image")
            output_frames = gr.Gallery(label="Inference Steps", columns=4, rows=2)

    # The examples and the button feed the same ordered set of controls.
    control_inputs = [image_input, model_choice, illusion_type, eps_slider, iterations_slider]

    # Each example row follows the order of ``control_inputs``.
    examples = [
        [os.path.join("stimuli", "Kanizsa_square.jpg"), "robust_resnet50", "Kanizsa", 0.5, 50],
        [os.path.join("stimuli", "face_vase.png"), "robust_resnet50", "Face-Vase", 0.5, 50],
        [os.path.join("stimuli", "figure_ground.png"), "robust_resnet50", "Figure-Ground", 0.7, 100],
        [os.path.join("stimuli", "NeonColorSaeedi.jpg"), "robust_resnet50", "Neon-Color", 0.3, 80],
    ]
    gr.Examples(examples=examples, inputs=control_inputs)

    # Clicking the button runs inference and fills both output components.
    run_button.click(
        fn=run_inference,
        inputs=control_inputs,
        outputs=[output_image, output_frames],
    )

    # Short description of the technique shown below the controls.
    gr.Markdown("""
## About Generative Inference
Generative inference is a technique that reveals how neural networks perceive visual stimuli by optimizing the input
to increase the network's confidence in its predictions. This process can reveal emergent perception of contours,
figure-ground separation, and other visual phenomena similar to human perception.
This demo allows you to:
1. Upload your own images or select from example illusions
2. Choose between robust or standard models
3. Adjust parameters like perturbation size (epsilon) and number of iterations
4. Visualize how the perception emerges over time
""")
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    print(f"Starting server on port {args.port}")
    launch_options = {
        "server_name": "0.0.0.0",  # bind to every network interface
        "server_port": args.port,  # port chosen on the command line
        "share": False,
        "debug": True,
    }
    demo.launch(**launch_options)