Spaces:
Sleeping
Sleeping
| """ | |
| ONNX Runtime inference backend for DeepNAPSI. | |
| Uses the dynamically-quantised INT8 BEiT model for fast CPU inference. | |
| The model is expected at model/model_int8.onnx (committed to the repo via | |
| Git LFS). If that file is absent it is downloaded from the HF Hub. | |
| """ | |
| from __future__ import annotations | |
| import concurrent.futures | |
| import os | |
| from pathlib import Path | |
| from typing import List | |
| import cv2 | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
| from nail_detection import get_nails_and_landmarks, draw_hand | |
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Bundled INT8 model location, plus the Hub repo / file used as a fallback.
MODEL_LOCAL = Path(__file__).parent.joinpath("model", "model_int8.onnx")
HF_REPO_ID = os.environ.get("DEEPNAPSI_HF_REPO", "lfolle/DeepNAPSIModel")
HF_FILENAME = "model_int8.onnx"

# BEiT preprocessing parameters (from timm resolve_data_config).
# A mean/std of 0.5 per channel maps [0, 1] pixel values onto [-1, 1].
INPUT_SIZE = 384
MEAN = np.full(3, 0.5, dtype=np.float32)
STD = np.full(3, 0.5, dtype=np.float32)

FINGER_NAMES = ["Thumb", "Index", "Middle", "Ring", "Pinky"]
NUM_CLASSES = 5  # NAPSI per-nail score classes 0..4

# ONNX Runtime thread count; override via the ORT_NUM_THREADS env var.
NUM_THREADS = int(os.environ.get("ORT_NUM_THREADS", "16"))
| # --------------------------------------------------------------------------- | |
| # Model loading | |
| # --------------------------------------------------------------------------- | |
| def _get_model_path() -> Path: | |
| # Env-var override (useful for local dev pointing at hf_space/model/) | |
| env_path = os.environ.get("DEEPNAPSI_MODEL_PATH", "") | |
| if env_path and Path(env_path).exists(): | |
| return Path(env_path) | |
| # Default local path (committed to Space via Git LFS, or pre-downloaded) | |
| if MODEL_LOCAL.exists(): | |
| return MODEL_LOCAL | |
| # Fallback: download from private HF Hub model repo. | |
| # Requires DEEPNAPSI_HF_TOKEN env var (set as a Space secret). | |
| from huggingface_hub import hf_hub_download | |
| token = os.environ.get("DEEPNAPSI_HF_TOKEN") or os.environ.get("DeepNAPSIModel") | |
| if not token: | |
| raise FileNotFoundError( | |
| f"Model not found at {MODEL_LOCAL} and neither DEEPNAPSI_HF_TOKEN nor " | |
| "DeepNAPSIModel secret is set. Set one in the Space settings." | |
| ) | |
| print(f"[backend] Downloading model from private repo {HF_REPO_ID} …") | |
| # Run inside a ThreadPoolExecutor so that huggingface_hub's internal asyncio | |
| # event loop is isolated; avoids the harmless-but-noisy | |
| # "Invalid file descriptor: -1" Python 3.12 GC warning on Space startup. | |
| with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: | |
| future = pool.submit(hf_hub_download, HF_REPO_ID, HF_FILENAME, token=token) | |
| path = future.result() | |
| return Path(path) | |
def _build_session(model_path: Path) -> ort.InferenceSession:
    """Create a CPU-only ONNX Runtime inference session for *model_path*.

    Both the intra-op and inter-op thread counts are set to NUM_THREADS,
    the parallel execution mode is selected, and all graph optimisations
    are enabled.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    # NOTE(review): with ORT_PARALLEL, ORT may run an inter-op pool alongside the
    # intra-op pool, i.e. up to ~2 * NUM_THREADS workers. On hosts with fewer
    # cores, ORT_SEQUENTIAL and/or a smaller ORT_NUM_THREADS may be faster —
    # benchmark before changing.
    options.intra_op_num_threads = options.inter_op_num_threads = NUM_THREADS
    cpu_only = ["CPUExecutionProvider"]
    return ort.InferenceSession(str(model_path), sess_options=options, providers=cpu_only)
| # --------------------------------------------------------------------------- | |
| # Preprocessing (replaces timm transforms, no heavy ML dependency at serve time) | |
| # --------------------------------------------------------------------------- | |
| def _preprocess(nail_rgb: np.ndarray) -> np.ndarray: | |
| """ | |
| Resize → CenterCrop → ToTensor → Normalize, matching BEiT training config. | |
| Returns float32 array [1, 3, 384, 384]. | |
| """ | |
| img = Image.fromarray(nail_rgb).convert("RGB") | |
| # Resize shortest side to INPUT_SIZE with bicubic | |
| w, h = img.size | |
| scale = INPUT_SIZE / min(w, h) | |
| new_w, new_h = max(INPUT_SIZE, round(w * scale)), max(INPUT_SIZE, round(h * scale)) | |
| img = img.resize((new_w, new_h), Image.BICUBIC) | |
| # CenterCrop | |
| left = (new_w - INPUT_SIZE) // 2 | |
| top = (new_h - INPUT_SIZE) // 2 | |
| img = img.crop((left, top, left + INPUT_SIZE, top + INPUT_SIZE)) | |
| # To float [0,1], normalise | |
| arr = np.asarray(img, dtype=np.float32) / 255.0 | |
| arr = (arr - MEAN) / STD | |
| return arr.transpose(2, 0, 1)[None] # [1, C, H, W] | |
| # --------------------------------------------------------------------------- | |
| # Inference with 3-view TTA | |
| # --------------------------------------------------------------------------- | |
| def _tta_logits(session: ort.InferenceSession, pixel_values: np.ndarray) -> np.ndarray: | |
| """Average logits over original + hflip + vflip views.""" | |
| views = [ | |
| pixel_values, | |
| pixel_values[:, :, :, ::-1].copy(), # horizontal flip | |
| pixel_values[:, :, ::-1, :].copy(), # vertical flip | |
| ] | |
| logits = np.stack( | |
| [session.run(None, {"pixel_values": v})[0] for v in views] | |
| ).mean(axis=0) | |
| return logits # [B, 5] | |
| # --------------------------------------------------------------------------- | |
| # Top-level backend class | |
| # --------------------------------------------------------------------------- | |
class Backend:
    """End-to-end DeepNAPSI scorer: hand/nail detection plus ONNX nail classification.

    A single ONNX Runtime session is created in ``__init__`` and reused by
    every ``predict`` call.
    """

    def __init__(self) -> None:
        # Resolve (downloading if needed) the model file, then build the session.
        model_path = _get_model_path()
        self._session = _build_session(model_path)
        print(f"[backend] Loaded model from {model_path} with {NUM_THREADS} ORT threads.")

    @staticmethod
    def _error_result(image_rgb: np.ndarray | None, message: str) -> dict:
        """Build the placeholder result returned when no NAPSI scores can be computed.

        Each of the five placeholder nail tiles is a separate array, so a
        consumer that modifies one tile in place does not alter the others.
        """
        return {
            "annotated_image": image_rgb,
            "nails": [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(5)],
            "napsi_scores": [-1] * 5,
            "napsi_sum": -1,
            "error": message,
        }

    def _score_nail(self, nail_rgb: np.ndarray) -> int:
        """Classify one nail crop; return the arg-max class of the TTA-averaged logits."""
        logits = _tta_logits(self._session, _preprocess(nail_rgb))  # [1, NUM_CLASSES]
        return int(np.argmax(logits, axis=-1)[0])

    def predict(self, image_rgb: np.ndarray | None) -> dict:
        """
        Run the full DeepNAPSI pipeline on a hand image.
        Args:
            image_rgb: HxWx3 uint8 RGB array (Gradio default), or None when no
                image was supplied.
        Returns:
            dict with keys:
                annotated_image – RGB image with hand skeleton drawn
                nails           – list of RGB nail crop arrays (5 placeholders on error)
                napsi_scores    – list of int NAPSI predictions (0-4), -1 on error
                napsi_sum       – int, sum of the scores, -1 on error
                error           – str | None
        """
        if image_rgb is None:
            # NOTE(review): presumably what an empty upload widget yields; without
            # this guard cv2.cvtColor raises on None.
            return self._error_result(
                None, "No image provided. Please upload a photo of one hand."
            )

        image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
        nails, landmarks = get_nails_and_landmarks(image_bgr)
        if nails is None or landmarks is None:
            return self._error_result(
                image_rgb, "No hand detected. Please upload a clear photo of one hand."
            )

        # Draw the hand skeleton on a copy of the BGR frame, then convert back.
        annotated = image_bgr.copy()
        draw_hand(annotated, landmarks)
        annotated_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

        # NOTE(review): per the original comments, the crops returned by
        # get_nails_and_landmarks are already RGB (the channel flip happens
        # inside extract_nails) — confirm in nail_detection.
        nail_rgbs: List[np.ndarray] = list(nails)
        napsi_scores: List[int] = [self._score_nail(nail) for nail in nail_rgbs]

        return {
            "annotated_image": annotated_rgb,
            "nails": nail_rgbs,
            "napsi_scores": napsi_scores,
            "napsi_sum": sum(napsi_scores),
            "error": None,
        }