Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,6 +17,8 @@ model_id = "Arrcttacsrks/netrunner-exe_Insight-Swap-models-onnx"
|
|
| 17 |
model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
|
| 18 |
|
| 19 |
def load_and_preprocess_image(image):
|
|
|
|
|
|
|
| 20 |
img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 21 |
img = cv2.resize(img, (512, 512))
|
| 22 |
img = img / 255.0 # Normalize to [0, 1]
|
|
@@ -26,24 +28,37 @@ def swap_faces(source_image, target_image):
|
|
| 26 |
# Load the ONNX model
|
| 27 |
session = ort.InferenceSession(model_file)
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# Preprocess the images
|
| 30 |
source_img = load_and_preprocess_image(source_image)
|
| 31 |
target_img = load_and_preprocess_image(target_image)
|
| 32 |
|
| 33 |
# Prepare input data for the model
|
| 34 |
input_data = np.array([source_img, target_img]).transpose(0, 3, 1, 2)
|
| 35 |
-
|
| 36 |
-
#
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Post-process the result
|
| 40 |
return np.clip(result * 255, 0, 255).astype(np.uint8)
|
| 41 |
|
| 42 |
# Create Gradio interface
|
| 43 |
-
interface = gr.Interface(
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
|
| 18 |
|
| 19 |
def load_and_preprocess_image(image):
|
| 20 |
+
if image is None or not isinstance(image, np.ndarray):
|
| 21 |
+
raise ValueError("Input image is not valid.")
|
| 22 |
img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 23 |
img = cv2.resize(img, (512, 512))
|
| 24 |
img = img / 255.0 # Normalize to [0, 1]
|
|
|
|
| 28 |
# Load the ONNX model
|
| 29 |
session = ort.InferenceSession(model_file)
|
| 30 |
|
| 31 |
+
# Print model input names for debugging
|
| 32 |
+
input_names = [input.name for input in session.get_inputs()]
|
| 33 |
+
print("Model input names:", input_names)
|
| 34 |
+
|
| 35 |
# Preprocess the images
|
| 36 |
source_img = load_and_preprocess_image(source_image)
|
| 37 |
target_img = load_and_preprocess_image(target_image)
|
| 38 |
|
| 39 |
# Prepare input data for the model
|
| 40 |
input_data = np.array([source_img, target_img]).transpose(0, 3, 1, 2)
|
| 41 |
+
|
| 42 |
+
# Validate input shape
|
| 43 |
+
expected_shape = (1, 3, 512, 512)
|
| 44 |
+
if input_data.shape != expected_shape:
|
| 45 |
+
raise ValueError(f"Input data shape {input_data.shape} does not match expected shape {expected_shape}.")
|
| 46 |
+
|
| 47 |
+
# Run inference using the correct input name
|
| 48 |
+
result = session.run(None, {input_names[0]: input_data})[0] # Use the first input name
|
| 49 |
|
| 50 |
# Post-process the result
|
| 51 |
return np.clip(result * 255, 0, 255).astype(np.uint8)
|
| 52 |
|
| 53 |
# Create Gradio interface
|
| 54 |
+
interface = gr.Interface(
|
| 55 |
+
fn=swap_faces,
|
| 56 |
+
inputs=["image", "image"],
|
| 57 |
+
outputs="image",
|
| 58 |
+
title="Face Swap using SimSwap",
|
| 59 |
+
description="Upload source and target images to swap faces.",
|
| 60 |
+
allow_flagging="never" # Prevent flagging to keep the interface clean
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Launch the interface with share=True for public access
|
| 64 |
+
interface.launch(share=True)
|