import os
import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from huggingface_hub import hf_hub_download
# Retrieve the Hugging Face access token from the environment
huggingface_token = os.getenv("HF_TOKEN")
if huggingface_token is None:
    raise ValueError("HF_TOKEN environment variable not set.")
# Download the model file from Hugging Face using the token
model_id = "Arrcttacsrks/netrunner-exe_Insight-Swap-models-onnx"
model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
def load_and_preprocess_image(image):
    """Convert an RGB uint8 frame to a 512x512 float32 BGR array in [0, 1]."""
    if image is None or not isinstance(image, np.ndarray):
        raise ValueError("Input image is not valid.")
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (512, 512))
    img = img / 255.0  # Normalize to [0, 1]
    return img.astype(np.float32)
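# Quick shape check (illustrative only, runs once at import): a dummy RGB frame
# pushed through the helper above should come out as a 512x512x3 float32 array
# in [0, 1]; swap_faces later transposes it to the NCHW layout the model expects.
_dummy_frame = np.zeros((720, 1280, 3), dtype=np.uint8)
assert load_and_preprocess_image(_dummy_frame).shape == (512, 512, 3)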
def create_identity_features(size=512):
    """Return placeholder identity features of the expected shape [1, 512].

    SimSwap is normally driven by an ArcFace embedding of the source face;
    an all-ones vector keeps the demo running but does not encode identity.
    """
    return np.ones((1, size), dtype=np.float32)
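# A minimal sketch of how a real identity embedding could be obtained, assuming
# the optional insightface package is installed (it is not a dependency of this
# app, and this helper is not called anywhere below):
def extract_identity_embedding(source_bgr):
    """Hypothetical helper: return a [1, 512] ArcFace embedding for the most
    prominent face in a BGR image, using insightface's FaceAnalysis."""
    from insightface.app import FaceAnalysis  # optional dependency
    analyzer = FaceAnalysis(name="buffalo_l")
    analyzer.prepare(ctx_id=0, det_size=(640, 640))
    faces = analyzer.get(source_bgr)
    if not faces:
        raise ValueError("No face detected in the source image.")
    return faces[0].normed_embedding.reshape(1, 512).astype(np.float32)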
def swap_faces(source_image, target_image):
    try:
        # Load the ONNX model (a new session is created on every call)
        session = ort.InferenceSession(model_file)
        # Get input names
        input_names = [inp.name for inp in session.get_inputs()]
        # Print input shapes for debugging
        for inp in session.get_inputs():
            print(f"Input '{inp.name}' expects shape: {inp.shape}")
        # Preprocess the source image
        source_img = load_and_preprocess_image(source_image)
        # First input: source image with shape [1, 3, 512, 512] (NCHW)
        source_input = source_img.transpose(2, 0, 1)[np.newaxis, ...]
        # Second input: identity features with shape [1, 512]
        # (placeholder values; the target image is not used here yet)
        identity_features = create_identity_features()
        # Create input dictionary
        input_dict = {
            input_names[0]: source_input.astype(np.float32),
            input_names[1]: identity_features
        }
        # Run inference
        result = session.run(None, input_dict)[0]
        # Post-process the result
        result = result[0].transpose(1, 2, 0)  # Convert from NCHW to HWC format
        result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)  # Convert back to RGB
        return np.clip(result * 255, 0, 255).astype(np.uint8)
    except Exception as e:
        print(f"Error during face swapping: {str(e)}")
        raise
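# Hypothetical helper (not wired into the Gradio app): a naive way to place the
# 512x512 swapped face back onto the full target image. A real pipeline would
# detect and align the target face and paste the result back under a mask;
# this simply resizes and alpha-blends, so treat it as a sketch only.
def composite_onto_target(swapped_face, target_image, alpha=0.7):
    h, w = target_image.shape[:2]
    resized = cv2.resize(swapped_face, (w, h))
    return cv2.addWeighted(resized, alpha, target_image, 1.0 - alpha, 0)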
# Create Gradio interface
interface = gr.Interface(
    fn=swap_faces,
    inputs=[
        gr.Image(label="Source Face", type="numpy"),
        gr.Image(label="Target Image", type="numpy")
    ],
    outputs=gr.Image(label="Result"),
    title="Face Swap using SimSwap",
    description="Upload a source face and a target image to swap faces. The source face will be transferred onto the target image.",
    flagging_mode="never"
)
# Launch the interface
if __name__ == "__main__":
    interface.launch(share=True)