Spaces:
Paused
Paused
Julian Bilcke
commited on
Commit
·
22b2a6e
1
Parent(s):
d96ce03
improve concurrency
Browse files- app.py +1 -1
- client/src/hooks/useFaceLandmarkDetection.tsx +1 -1
- engine.py +6 -6
app.py
CHANGED
|
@@ -79,7 +79,7 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
|
|
| 79 |
|
| 80 |
elif msg.type == WSMsgType.TEXT:
|
| 81 |
data = json.loads(msg.data)
|
| 82 |
-
webp_bytes = engine.transform_image(data.get('hash'), data.get('params'))
|
| 83 |
await ws.send_bytes(webp_bytes)
|
| 84 |
|
| 85 |
except Exception as e:
|
|
|
|
| 79 |
|
| 80 |
elif msg.type == WSMsgType.TEXT:
|
| 81 |
data = json.loads(msg.data)
|
| 82 |
+
webp_bytes = await engine.transform_image(data.get('hash'), data.get('params'))
|
| 83 |
await ws.send_bytes(webp_bytes)
|
| 84 |
|
| 85 |
except Exception as e:
|
client/src/hooks/useFaceLandmarkDetection.tsx
CHANGED
|
@@ -18,7 +18,7 @@ export function useFaceLandmarkDetection() {
|
|
| 18 |
// if we only send the face/square then we can use 138ms
|
| 19 |
// unfortunately it doesn't work well yet
|
| 20 |
// const throttleInMs = 138ms
|
| 21 |
-
const throttleInMs =
|
| 22 |
////////////////////////////////////////////////////////////////////////
|
| 23 |
|
| 24 |
// State for face detection
|
|
|
|
| 18 |
// if we only send the face/square then we can use 138ms
|
| 19 |
// unfortunately it doesn't work well yet
|
| 20 |
// const throttleInMs = 138ms
|
| 21 |
+
const throttleInMs = 220
|
| 22 |
////////////////////////////////////////////////////////////////////////
|
| 23 |
|
| 24 |
// State for face detection
|
engine.py
CHANGED
|
@@ -129,7 +129,7 @@ class Engine:
|
|
| 129 |
# 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
|
| 130 |
}
|
| 131 |
|
| 132 |
-
def transform_image(self, image_hash: str, params: Dict[str, float]) -> bytes:
|
| 133 |
# If we don't have the image in cache yet, add it
|
| 134 |
if image_hash not in self.processed_cache:
|
| 135 |
raise ValueError("cache miss")
|
|
@@ -197,11 +197,11 @@ class Engine:
|
|
| 197 |
x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
|
| 198 |
|
| 199 |
# Apply stitching
|
| 200 |
-
x_d_new = self.live_portrait.live_portrait_wrapper.stitching
|
| 201 |
|
| 202 |
# Generate the output
|
| 203 |
-
out = self.live_portrait.live_portrait_wrapper.warp_decode
|
| 204 |
-
I_p = self.live_portrait.live_portrait_wrapper.parse_output
|
| 205 |
|
| 206 |
buffered = io.BytesIO()
|
| 207 |
|
|
@@ -214,11 +214,11 @@ class Engine:
|
|
| 214 |
# I'm currently running some experiments to do it in the frontend
|
| 215 |
#
|
| 216 |
# --- old way: we do it in the server-side: ---
|
| 217 |
-
mask_ori = prepare_paste_back
|
| 218 |
processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
|
| 219 |
dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
|
| 220 |
)
|
| 221 |
-
I_p_to_ori_blend = paste_back
|
| 222 |
I_p[0], processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
|
| 223 |
)
|
| 224 |
result_image = Image.fromarray(I_p_to_ori_blend)
|
|
|
|
| 129 |
# 'bbox_rot': bbox_info['bbox_rot'].toList(), # 4x2
|
| 130 |
}
|
| 131 |
|
| 132 |
+
async def transform_image(self, image_hash: str, params: Dict[str, float]) -> bytes:
|
| 133 |
# If we don't have the image in cache yet, add it
|
| 134 |
if image_hash not in self.processed_cache:
|
| 135 |
raise ValueError("cache miss")
|
|
|
|
| 197 |
x_d_new = processed_data['x_s_info']['scale'] * (x_d_new @ R_new) + processed_data['x_s_info']['t']
|
| 198 |
|
| 199 |
# Apply stitching
|
| 200 |
+
x_d_new = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.stitching, processed_data['x_s'], x_d_new)
|
| 201 |
|
| 202 |
# Generate the output
|
| 203 |
+
out = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.warp_decode, processed_data['f_s'], processed_data['x_s'], x_d_new)
|
| 204 |
+
I_p = await asyncio.to_thread(self.live_portrait.live_portrait_wrapper.parse_output, out['out'])
|
| 205 |
|
| 206 |
buffered = io.BytesIO()
|
| 207 |
|
|
|
|
| 214 |
# I'm currently running some experiments to do it in the frontend
|
| 215 |
#
|
| 216 |
# --- old way: we do it in the server-side: ---
|
| 217 |
+
mask_ori = await asyncio.to_thread(prepare_paste_back,
|
| 218 |
processed_data['inference_cfg'].mask_crop, processed_data['crop_info']['M_c2o'],
|
| 219 |
dsize=(processed_data['img_rgb'].shape[1], processed_data['img_rgb'].shape[0])
|
| 220 |
)
|
| 221 |
+
I_p_to_ori_blend = await asyncio.to_thread(paste_back,
|
| 222 |
I_p[0], processed_data['crop_info']['M_c2o'], processed_data['img_rgb'], mask_ori
|
| 223 |
)
|
| 224 |
result_image = Image.fromarray(I_p_to_ori_blend)
|