Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| # Copyright (C) 2025 NVIDIA Corporation. All rights reserved. | |
| # | |
| # This work is licensed under the LICENSE file | |
| # located at the root directory. | |
import gc
import os
import tempfile
import traceback
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

from addit_flux_pipeline import AdditFluxPipeline
from addit_flux_transformer import AdditFluxTransformer2DModel
from addit_scheduler import AdditFlowMatchEulerDiscreteScheduler
from addit_methods import add_object_generated, add_object_real
# Global state shared by the request handlers below.
pipe = None  # loaded AdditFluxPipeline; None whenever initialization did not fully succeed
device = None  # torch.device the pipeline is moved to
original_image_size = None  # (width, height) of the most recent upload; written by handle_image_upload

# Hugging Face Hub repository that supplies both the transformer and the pipeline weights.
MODEL_ID = "black-forest-labs/FLUX.1-dev"

# Initialize the model once at startup. Failure is deliberately non-fatal: the UI
# still launches, and the handlers answer "Model not initialized" because `pipe`
# is None.
# NOTE(review): `spaces` is imported but no handler is decorated with
# @spaces.GPU; if this Space runs on ZeroGPU hardware that decorator is
# required — confirm the target hardware.
print("Initializing ADDIT model...")
try:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the Add-it variant of the FLUX transformer.
    my_transformer = AdditFluxTransformer2DModel.from_pretrained(
        MODEL_ID,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # Build the pipeline around the custom transformer and move it to the device.
    pipe = AdditFluxPipeline.from_pretrained(
        MODEL_ID,
        transformer=my_transformer,
        torch_dtype=torch.bfloat16,
    ).to(device)

    # Replace the stock scheduler with the Add-it one, reusing its config.
    pipe.scheduler = AdditFlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config)
    print("Model initialized successfully!")
except Exception as e:
    # `pipe` may already hold a half-configured pipeline (e.g. when only the
    # scheduler swap failed); reset it so the UI and handlers consistently
    # report the model as unavailable instead of running with the wrong setup.
    pipe = None
    print(f"Error initializing model: {e}")
    # The full traceback is what makes download/auth/OOM failures diagnosable.
    traceback.print_exc()
    print("The application will start but model functionality will be unavailable.")
def validate_inputs(prompt_source, prompt_target, subject_token):
    """Validate the text inputs shared by both editing modes.

    Args:
        prompt_source: Prompt describing the original scene.
        prompt_target: Prompt describing the scene with the object added.
        subject_token: Text naming the object to insert; it must occur in
            ``prompt_target``.

    Returns:
        The error message for the first failing check, or ``None`` when all
        inputs are acceptable. ``None`` values (e.g. a field omitted in an API
        call) are treated as empty strings rather than raising.
    """
    prompt_source = prompt_source or ""
    prompt_target = prompt_target or ""
    subject_token = subject_token or ""

    if not prompt_source.strip():
        return "Source prompt cannot be empty"
    if not prompt_target.strip():
        return "Target prompt cannot be empty"
    if not subject_token.strip():
        return "Subject token cannot be empty"
    # NOTE: plain substring test, so e.g. "hat" is also accepted inside "that".
    if subject_token not in prompt_target:
        return f"Subject token '{subject_token}' must appear in the target prompt"
    return None
def _cover_size(width, height, target):
    """Return the aspect-preserving (w, h) whose shorter side is exactly ``target``."""
    # The shorter side is pinned instead of computed as int(side * (target / side)):
    # float rounding can make that product fall just below `target` (a 784-px side
    # gives 1023), after which the center-crop box exceeded the resized image and
    # the result was padded with a black strip. max() guards the longer side too.
    if width <= height:
        return target, max(target, round(height * target / width))
    return max(target, round(width * target / height)), target


def resize_and_crop_image(image, target_size=1024):
    """Resize and center-crop ``image`` to a ``target_size`` square.

    The shorter side is scaled to ``target_size`` (aspect ratio preserved) and
    the longer side is then center-cropped.

    Args:
        image: A PIL image, or ``None``.
        target_size: Edge length of the square output, in pixels.

    Returns:
        ``(processed_image, status_html, original_size)``.
        - ``(None, "", None)`` when ``image`` is ``None``.
        - The image unchanged with an empty message when it is already
          ``target_size`` × ``target_size``.
        - Otherwise ``status_html`` is a green banner saying whether the image
          was only resized or also cropped. ``original_size`` is the input's
          ``(width, height)``.
    """
    if image is None:
        return None, "", None

    original_width, original_height = image.size
    original_size = (original_width, original_height)

    # Already the right size: nothing to do.
    if original_size == (target_size, target_size):
        return image, "", original_size

    new_width, new_height = _cover_size(original_width, original_height, target_size)
    resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Center crop to target_size × target_size.
    left = (new_width - target_size) // 2
    top = (new_height - target_size) // 2
    cropped_image = resized_image.crop((left, top, left + target_size, top + target_size))

    if (new_width, new_height) == (target_size, target_size):
        action = f"Image resized to {target_size}×{target_size}"
    else:
        action = f"Image resized and center cropped to {target_size}×{target_size}"
    message = (
        "<div style='background-color: #e8f5e8; border: 1px solid #4caf50; "
        "border-radius: 5px; padding: 8px; margin-bottom: 10px;'>"
        f"<span style='color: #2e7d32; font-weight: bold;'>✅ {action}</span></div>"
    )
    return cropped_image, message, original_size
def handle_image_upload(image):
    """Preprocess a freshly uploaded image and record its original size.

    Args:
        image: The uploaded PIL image, or ``None`` when there is none.

    Returns:
        ``(processed_image, status_html)`` for the image component and the
        status banner, or ``(None, "")`` when ``image`` is ``None``.
    """
    global original_image_size

    # NOTE(review): this module-level global is shared by every browser
    # session, so concurrent users overwrite each other's value. It is only
    # read for logging in process_real_image; consider gr.State for
    # per-session storage.
    if image is not None:
        original_image_size = image.size
        processed, status_html, _unused_size = resize_and_crop_image(image)
        return processed, status_html

    original_image_size = None
    return None, ""
def process_generated_image(
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a base image from ``prompt_source`` and add an object to it with Add-it.

    Args:
        prompt_source: Prompt for the generated base image.
        prompt_target: Prompt describing the base scene plus the new object.
        subject_token: Text in ``prompt_target`` naming the object to insert.
        seed_src: "Source Seed" from the UI, forwarded to ``add_object_generated``.
        seed_obj: "Object Seed" from the UI, forwarded to ``add_object_generated``.
        extended_scale: Forwarded unchanged to ``add_object_generated``.
        structure_transfer_step: Forwarded unchanged to ``add_object_generated``.
        blend_steps: Comma-separated integers (e.g. ``"15,20"``). Empty or
            ``None`` yields an empty list.
        localization_model: Localization strategy name, forwarded unchanged.
        progress: Not read directly. Declaring it with ``track_tqdm=True`` lets
            Gradio mirror the pipeline's tqdm progress in the UI.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` when the model is not loaded, an input is invalid, or
        generation raises.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."

    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Log the request with a timestamp.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Generated Image Processing")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")

    # Parse blend steps up front so a typo gets a clear message rather than a
    # raw int() error from inside the generation path.
    try:
        blend_steps_list = [int(step.strip()) for step in (blend_steps or "").split(",") if step.strip()]
    except ValueError:
        return None, None, f"Invalid Blend Steps '{blend_steps}': expected comma-separated integers (e.g. '15,20')"

    try:
        src_image, edited_image = add_object_generated(
            pipe=pipe,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            show_attention=False,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            display_output=False,
        )
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback in the logs.
        error_msg = f"Error generating images: {e}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg

    return src_image, edited_image, "Images generated successfully!"
def process_real_image(
    source_image,
    prompt_source,
    prompt_target,
    subject_token,
    seed_src,
    seed_obj,
    extended_scale,
    structure_transfer_step,
    blend_steps,
    localization_model,
    use_offset,
    disable_inversion,
    progress=gr.Progress(track_tqdm=True),
):
    """Insert an object into a user-supplied image with Add-it.

    Args:
        source_image: The PIL image to edit. It is center-cropped to 1024×1024
            before use.
        prompt_source: Prompt describing ``source_image``.
        prompt_target: Prompt describing the scene with the new object.
        subject_token: Text in ``prompt_target`` naming the object to insert.
        seed_src: "Source Seed" from the UI, forwarded to ``add_object_real``.
        seed_obj: "Object Seed" from the UI, forwarded to ``add_object_real``.
        extended_scale: Forwarded unchanged to ``add_object_real``.
        structure_transfer_step: Forwarded unchanged to ``add_object_real``.
        blend_steps: Comma-separated integers (e.g. ``"15,20"``). Empty or
            ``None`` yields an empty list.
        localization_model: Localization strategy name, forwarded unchanged.
        use_offset: Forwarded unchanged as ``use_offset``.
        disable_inversion: When true, passes ``use_inversion=False``.
        progress: Not read directly. Declaring it with ``track_tqdm=True`` lets
            Gradio mirror the pipeline's tqdm progress in the UI.

    Returns:
        ``(source_image, edited_image, status_message)``. Both images are
        ``None`` when the model is not loaded, an input is missing or invalid,
        or processing raises.
    """
    if pipe is None:
        return None, None, "Model not initialized. Please restart the application."
    if source_image is None:
        return None, None, "Please upload a source image"

    error_msg = validate_inputs(prompt_source, prompt_target, subject_token)
    if error_msg:
        return None, None, error_msg

    # Log the request with a timestamp.
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f"\n[{current_time}] Starting Real Image Processing")
    if original_image_size:
        print(f"Original uploaded image size: {original_image_size[0]}×{original_image_size[1]}")
    print(f"Source Image Size: {source_image.size}")
    print(f"Source Prompt: '{prompt_source}'")
    print(f"Target Prompt: '{prompt_target}'")
    print(f"Subject Token: '{subject_token}'")
    print(f"Source Seed: {seed_src}, Object Seed: {seed_obj}")
    print(f"Extended Scale: {extended_scale}, Structure Transfer Step: {structure_transfer_step}")
    print(f"Blend Steps: '{blend_steps}', Localization Model: '{localization_model}'")
    print(f"Use Offset: {use_offset}, Disable Inversion: {disable_inversion}")

    # Parse blend steps up front so a typo gets a clear message rather than a
    # raw int() error from inside the editing path.
    try:
        blend_steps_list = [int(step.strip()) for step in (blend_steps or "").split(",") if step.strip()]
    except ValueError:
        return None, None, f"Invalid Blend Steps '{blend_steps}': expected comma-separated integers (e.g. '15,20')"

    try:
        # Uploads were already center-cropped by handle_image_upload, but images
        # filled in via gr.Examples bypass the upload event. Cropping here (a
        # no-op for 1024×1024 input) keeps their aspect ratio instead of
        # squashing them with a plain resize.
        source_image, _, _ = resize_and_crop_image(source_image)

        src_image, edited_image = add_object_real(
            pipe=pipe,
            source_image=source_image,
            prompt_source=prompt_source,
            prompt_object=prompt_target,
            subject_token=subject_token,
            seed_src=seed_src,
            seed_obj=seed_obj,
            extended_scale=extended_scale,
            structure_transfer_step=structure_transfer_step,
            blend_steps=blend_steps_list,
            localization_model=localization_model,
            use_offset=use_offset,
            show_attention=False,
            use_inversion=not disable_inversion,
            display_output=False,
        )
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback in the logs.
        error_msg = f"Error processing image: {e}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg

    return src_image, edited_image, "Image edited successfully!"
def _show_banner_if_nonempty(status):
    """Return a Gradio update that shows the status banner only when it has text."""
    return gr.update(visible=bool((status or "").strip()), value=status)


def _build_generated_tab():
    """Add the "Generated Images" tab: inputs, outputs, submit wiring and examples."""
    with gr.TabItem("🎭 Generated Images"):
        gr.Markdown("### Generate a base image and add objects to it")
        with gr.Row():
            with gr.Column(scale=1):
                gen_prompt_source = gr.Textbox(
                    label="Source Prompt",
                    placeholder="A photo of a cat sitting on the couch",
                    value="A photo of a cat sitting on the couch",
                )
                gen_prompt_target = gr.Textbox(
                    label="Target Prompt",
                    placeholder="A photo of a cat wearing a blue hat sitting on the couch",
                    value="A photo of a cat wearing a blue hat sitting on the couch",
                )
                gen_subject_token = gr.Textbox(
                    label="Subject Token",
                    placeholder="hat",
                    value="hat",
                    info="Single token representing the object to add **(must appear in target prompt)**",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    gen_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                    gen_seed_obj = gr.Number(label="Object Seed", value=42, precision=0)
                    gen_extended_scale = gr.Slider(
                        label="Extended Scale",
                        minimum=1.0,
                        maximum=1.3,
                        value=1.05,
                        step=0.01,
                    )
                    gen_structure_transfer_step = gr.Slider(
                        label="Structure Transfer Step",
                        minimum=0,
                        maximum=10,
                        value=2,
                        step=1,
                    )
                    gen_blend_steps = gr.Textbox(
                        label="Blend Steps",
                        value="15",
                        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
                    )
                    gen_localization_model = gr.Dropdown(
                        label="Localization Model",
                        choices=[
                            "attention_points_sam",
                            "attention",
                            "attention_box_sam",
                            "attention_mask_sam",
                            "grounding_sam",
                        ],
                        value="attention_points_sam",
                    )
                gen_submit_btn = gr.Button("🎨 Generate & Edit", variant="primary")
            with gr.Column(scale=2):
                with gr.Row():
                    gen_src_output = gr.Image(label="Generated Source Image", type="pil")
                    gen_edited_output = gr.Image(label="Edited Image", type="pil")
                gen_status = gr.Textbox(label="Status", interactive=False)

        gen_submit_btn.click(
            fn=process_generated_image,
            inputs=[
                gen_prompt_source, gen_prompt_target, gen_subject_token,
                gen_seed_src, gen_seed_obj, gen_extended_scale,
                gen_structure_transfer_step, gen_blend_steps,
                gen_localization_model,
            ],
            outputs=[gen_src_output, gen_edited_output, gen_status],
        )

        gr.Examples(
            examples=[
                ["An empty throne", "A king sitting on a throne", "king"],
                ["A photo of a man sitting on a bench", "A photo of a man sitting on a bench with a dog", "dog"],
                ["A photo of a cat sitting on the couch", "A photo of a cat wearing a blue hat sitting on the couch", "hat"],
                ["A car driving through an empty street", "A pink car driving through an empty street", "car"],
            ],
            inputs=[gen_prompt_source, gen_prompt_target, gen_subject_token],
            label="Example Prompts",
        )


def _build_real_tab():
    """Add the "Real Images" tab: upload preprocessing, inputs, outputs, submit wiring and examples."""
    with gr.TabItem("📸 Real Images"):
        gr.Markdown("### Upload an image and add objects to it")
        gr.HTML("<p style='color: orange; font-weight: bold; margin: -15px -10px;'>Note: Images will be automatically resized and center cropped to 1024×1024 pixels.</p>")
        with gr.Row():
            with gr.Column(scale=1):
                real_image_status = gr.HTML(visible=False)
                real_source_image = gr.Image(label="Source Image", type="pil")
                real_prompt_source = gr.Textbox(
                    label="Source Prompt",
                    placeholder="A photo of a bed in a dark room",
                    value="A photo of a bed in a dark room",
                )
                real_prompt_target = gr.Textbox(
                    label="Target Prompt",
                    placeholder="A photo of a dog lying on a bed in a dark room",
                    value="A photo of a dog lying on a bed in a dark room",
                )
                real_subject_token = gr.Textbox(
                    label="Subject Token",
                    placeholder="dog",
                    value="dog",
                    info="Single token representing the object to add **(must appear in target prompt)**",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    real_seed_src = gr.Number(label="Source Seed", value=1, precision=0)
                    real_seed_obj = gr.Number(label="Object Seed", value=0, precision=0)
                    real_extended_scale = gr.Slider(
                        label="Extended Scale",
                        minimum=1.0,
                        maximum=1.3,
                        value=1.1,
                        step=0.01,
                    )
                    real_structure_transfer_step = gr.Slider(
                        label="Structure Transfer Step",
                        minimum=0,
                        maximum=10,
                        value=4,
                        step=1,
                    )
                    real_blend_steps = gr.Textbox(
                        label="Blend Steps",
                        value="18",
                        info="Comma-separated list of steps (e.g., '15,20') or empty for no blending",
                    )
                    real_localization_model = gr.Dropdown(
                        label="Localization Model",
                        choices=[
                            "attention",
                            "attention_points_sam",
                            "attention_box_sam",
                            "attention_mask_sam",
                            "grounding_sam",
                        ],
                        value="attention",
                    )
                    real_use_offset = gr.Checkbox(label="Use Offset", value=False)
                    real_disable_inversion = gr.Checkbox(label="Disable Inversion", value=False)
                real_submit_btn = gr.Button("🎨 Edit Image", variant="primary")
            with gr.Column(scale=2):
                with gr.Row():
                    real_src_output = gr.Image(label="Source Image", type="pil")
                    real_edited_output = gr.Image(label="Edited Image", type="pil")
                real_status = gr.Textbox(label="Status", interactive=False)

        # Preprocess each upload, then show the banner only if it has a message.
        real_source_image.upload(
            fn=handle_image_upload,
            inputs=[real_source_image],
            outputs=[real_source_image, real_image_status],
        ).then(
            fn=_show_banner_if_nonempty,
            inputs=[real_image_status],
            outputs=[real_image_status],
        )
        # Clearing the image does not fire `upload`; without this handler the
        # previous "resized" banner and recorded original size would linger.
        real_source_image.clear(
            fn=lambda: handle_image_upload(None),
            outputs=[real_source_image, real_image_status],
        ).then(
            fn=_show_banner_if_nonempty,
            inputs=[real_image_status],
            outputs=[real_image_status],
        )

        real_submit_btn.click(
            fn=process_real_image,
            inputs=[
                real_source_image, real_prompt_source, real_prompt_target, real_subject_token,
                real_seed_src, real_seed_obj, real_extended_scale,
                real_structure_transfer_step, real_blend_steps,
                real_localization_model, real_use_offset,
                real_disable_inversion,
            ],
            outputs=[real_src_output, real_edited_output, real_status],
        )

        gr.Examples(
            examples=[
                [
                    "images/bed_dark_room.jpg",
                    "A photo of a bed in a dark room",
                    "A photo of a dog lying on a bed in a dark room",
                    "dog",
                ],
                [
                    "images/flower.jpg",
                    "A photo of a flower",
                    "A bee standing on a flower",
                    "bee",
                ],
            ],
            inputs=[real_source_image, real_prompt_source, real_prompt_target, real_subject_token],
            label="Example Images & Prompts",
        )


def create_interface():
    """Build and return the (not yet launched) Add-it Gradio Blocks app.

    The header shows whether the module-level ``pipe`` finished loading.
    """
    model_status = "Model ready!" if pipe is not None else "Model initialization failed - functionality unavailable"
    with gr.Blocks(title="🎨 Add-it: Training-Free Object Insertion in Images With Pretrained Diffusion Models", theme=gr.themes.Soft()) as demo:
        gr.HTML(f"""
        <div style="text-align: center; margin-bottom: 20px;">
        <h1>🎨 Add-it: Training-Free Object Insertion</h1>
        <p>Add objects to images using pretrained diffusion models</p>
        <p><a href="https://research.nvidia.com/labs/par/addit/" target="_blank">🌐 Project Website</a>
        <a href="https://arxiv.org/abs/2411.07232" target="_blank">📄 Paper</a>
        <a href="https://github.com/NVlabs/addit" target="_blank">💻 Code</a></p>
        <p style="color: {'green' if pipe is not None else 'red'}; font-weight: bold;">Status: {model_status}</p>
        </div>
        """)

        with gr.Tabs():
            _build_generated_tab()
            _build_real_tab()

        # The tips no longer mention a "Show Attention" control: none exists in
        # this UI (both handlers pass show_attention=False).
        with gr.Accordion("💡 Tips for Better Results", open=False):
            gr.Markdown("""
            - **Prompt Design**: The Target Prompt should be similar to the Source Prompt, but include a description of the new object to insert
            - **Seed Variation**: Try different values for Object Seed - some prompts may require a few attempts to get satisfying results
            - **Localization Models**: The most effective options are `attention_points_sam` and `attention`
            - **Object Placement Issues**: If the object is not added to the image:
              - Try **decreasing** Structure Transfer Step
              - Try **increasing** Extended Scale
            - **Flexibility**: To allow more flexibility in modifying the source image, leave Blend Steps empty to send an empty list
            """)
    return demo
# Built at import time so tools that look up a module-level `demo` still find it.
demo = create_interface()

if __name__ == "__main__":
    # Launch only when run as a script, not when the module is imported.
    # Defaults match the original hard-coded settings; the environment
    # variables allow overriding them without editing the file.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        share=os.environ.get("GRADIO_SHARE", "true").strip().lower() in ("1", "true", "yes"),
    )