import os
import cv2
import numpy as np
import onnxruntime as ort
import gradio as gr
from huggingface_hub import hf_hub_download

# Read the Hugging Face token from the environment and fail early if it is missing
huggingface_token = os.getenv("HF_TOKEN")
if huggingface_token is None:
    raise ValueError("HF_TOKEN environment variable not set.")

# Download the model file from Hugging Face using the token
model_id = "Arrcttacsrks/netrunner-exe_Insight-Swap-models-onnx"
model_file = hf_hub_download(repo_id=model_id, filename="simswap_512_unoff.onnx", token=huggingface_token)
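# Note: hf_hub_download returns a local file path and caches the download, so
# subsequent launches should reuse the cached model rather than re-downloading it.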

def load_and_preprocess_image(image):
    if image is None or not isinstance(image, np.ndarray):
        raise ValueError("Input image is not valid.")
    
    img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    img = cv2.resize(img, (512, 512))
    img = img / 255.0  # Normalize to [0, 1]
    return img.astype(np.float32)

def swap_faces(source_image, target_image):
    try:
        # Load the ONNX model
        session = ort.InferenceSession(model_file)
        
        # Get input names and print their expected shapes for debugging
        # (avoid shadowing the built-in name "input")
        input_names = [inp.name for inp in session.get_inputs()]
        for inp in session.get_inputs():
            print(f"Input '{inp.name}' expects shape: {inp.shape}")
        
        # Preprocess the images
        source_img = load_and_preprocess_image(source_image)
        target_img = load_and_preprocess_image(target_image)
        
        # Reshape inputs to match the model's expectations
        # First input: the image, in NCHW layout
        source_input = source_img.transpose(2, 0, 1)[np.newaxis, ...]  # Shape: (1, 3, 512, 512)
        
        # Second input (onnx::Gemm_1): the model expects a rank-2 tensor,
        # so flatten the target image into a 2D array with 512 columns
        target_features = target_img.transpose(2, 0, 1).reshape(-1, 512)  # Shape: (1536, 512)
        
        # Create input dictionary
        input_dict = {
            input_names[0]: source_input.astype(np.float32),
            input_names[1]: target_features.astype(np.float32)
        }
        
        # Run inference
        result = session.run(None, input_dict)[0]
        
        # Post-process the result
        result = result[0].transpose(1, 2, 0)  # Convert from NCHW to HWC format
        result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)  # Convert back to RGB
        return np.clip(result * 255, 0, 255).astype(np.uint8)
    
    except Exception as e:
        print(f"Error during face swapping: {str(e)}")
        raise

# Create Gradio interface
interface = gr.Interface(
    fn=swap_faces,
    inputs=[
        gr.Image(label="Source Face", type="numpy"),
        gr.Image(label="Target Image", type="numpy")
    ],
    outputs=gr.Image(label="Result"),
    title="Face Swap using SimSwap",
    description="Upload a source face and a target image to swap faces. The source face will be transferred onto the target image.",
    allow_flagging="never"
)
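
# Example (hypothetical) of calling swap_faces directly, without the Gradio UI.
# The file names below are placeholders; images are read with OpenCV and
# converted to RGB because the interface passes RGB numpy arrays.
#
#   src = cv2.cvtColor(cv2.imread("source.jpg"), cv2.COLOR_BGR2RGB)
#   tgt = cv2.cvtColor(cv2.imread("target.jpg"), cv2.COLOR_BGR2RGB)
#   out = swap_faces(src, tgt)
#   cv2.imwrite("result.jpg", cv2.cvtColor(out, cv2.COLOR_RGB2BGR))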

# Launch the interface
if __name__ == "__main__":
    interface.launch()