Arrcttacsrks commited on
Commit
a044f17
·
verified ·
1 Parent(s): dad1d2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -10
app.py CHANGED
@@ -17,6 +17,8 @@ model_id = "Arrcttacsrks/netrunner-exe_Insight-Swap-models-onnx"
17
  model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
18
 
19
  def load_and_preprocess_image(image):
 
 
20
  img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
21
  img = cv2.resize(img, (512, 512))
22
  img = img / 255.0 # Normalize to [0, 1]
@@ -26,24 +28,37 @@ def swap_faces(source_image, target_image):
26
  # Load the ONNX model
27
  session = ort.InferenceSession(model_file)
28
 
 
 
 
 
29
  # Preprocess the images
30
  source_img = load_and_preprocess_image(source_image)
31
  target_img = load_and_preprocess_image(target_image)
32
 
33
  # Prepare input data for the model
34
  input_data = np.array([source_img, target_img]).transpose(0, 3, 1, 2)
35
-
36
- # Run inference
37
- result = session.run(None, {'input': input_data})[0]
 
 
 
 
 
38
 
39
  # Post-process the result
40
  return np.clip(result * 255, 0, 255).astype(np.uint8)
41
 
42
  # Create Gradio interface
43
- interface = gr.Interface(fn=swap_faces,
44
- inputs=["image", "image"],
45
- outputs="image",
46
- title="Face Swap using SimSwap",
47
- description="Upload source and target images to swap faces.")
48
- # Launch the interface
49
- interface.launch()
 
 
 
 
 
17
  model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
18
 
19
  def load_and_preprocess_image(image):
20
+ if image is None or not isinstance(image, np.ndarray):
21
+ raise ValueError("Input image is not valid.")
22
  img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
23
  img = cv2.resize(img, (512, 512))
24
  img = img / 255.0 # Normalize to [0, 1]
 
28
  # Load the ONNX model
29
  session = ort.InferenceSession(model_file)
30
 
31
+ # Print model input names for debugging
32
+ input_names = [input.name for input in session.get_inputs()]
33
+ print("Model input names:", input_names)
34
+
35
  # Preprocess the images
36
  source_img = load_and_preprocess_image(source_image)
37
  target_img = load_and_preprocess_image(target_image)
38
 
39
  # Prepare input data for the model
40
  input_data = np.array([source_img, target_img]).transpose(0, 3, 1, 2)
41
+
42
+ # Validate input shape
43
+ expected_shape = (1, 3, 512, 512)
44
+ if input_data.shape != expected_shape:
45
+ raise ValueError(f"Input data shape {input_data.shape} does not match expected shape {expected_shape}.")
46
+
47
+ # Run inference using the correct input name
48
+ result = session.run(None, {input_names[0]: input_data})[0] # Use the first input name
49
 
50
  # Post-process the result
51
  return np.clip(result * 255, 0, 255).astype(np.uint8)
52
 
53
  # Create Gradio interface
54
+ interface = gr.Interface(
55
+ fn=swap_faces,
56
+ inputs=["image", "image"],
57
+ outputs="image",
58
+ title="Face Swap using SimSwap",
59
+ description="Upload source and target images to swap faces.",
60
+ allow_flagging="never" # Prevent flagging to keep the interface clean
61
+ )
62
+
63
+ # Launch the interface with share=True for public access
64
+ interface.launch(share=True)