Spaces:
Running
Running
| """ | |
| REST API сервер для сегментации изображений через SAM2. | |
| Уставший сеньор кодит это в 3 часа ночи, поэтому код местами будет грязный. | |
| """ | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Body | |
| from fastapi.responses import JSONResponse, HTMLResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import io | |
| import os | |
| import base64 | |
| import cv2 | |
| from typing import List, Dict, Any, Optional, Literal | |
| import logging | |
| from datetime import datetime | |
| import json | |
| # Настройка логирования, потому что дебажить это говно иначе невозможно | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Глобальные переменные для модели (лень каждый раз загружать) | |
| predictor = None | |
| device = None | |
| # ===== Pydantic модели для батчинг API ===== | |
| class BBoxModel(BaseModel): | |
| """Bounding box в нормализованных координатах (0.0 - 1.0) или пиксельных""" | |
| x_min: float = Field(..., description="X координата левого верхнего угла") | |
| y_min: float = Field(..., description="Y координата левого верхнего угла") | |
| x_max: float = Field(..., description="X координата правого нижнего угла") | |
| y_max: float = Field(..., description="Y координата правого нижнего угла") | |
| class PromptModel(BaseModel): | |
| """Промпт для сегментации одного объекта""" | |
| id: int = Field(..., description="Уникальный ID объекта") | |
| type: Literal["mask", "box", "points"] = Field(..., description="Тип промпта") | |
| data: str = Field(..., description="Данные промпта (base64 для mask, JSON для points)") | |
| bbox: Optional[BBoxModel] = Field(None, description="Опциональный bounding box") | |
| label: Optional[str] = Field(None, description="Метка объекта (person, car, etc)") | |
| selected: bool = Field(True, description="Обрабатывать ли этот промпт") | |
| class SegmentOptionsModel(BaseModel): | |
| """Опции сегментации""" | |
| extract_objects: bool = Field(True, description="Вернуть вырезанные объекты") | |
| include_masks: bool = Field(False, description="Включить контуры масок") | |
| clean_masks: bool = Field(True, description="Очистить маски от артефактов") | |
| class BatchSegmentRequest(BaseModel): | |
| """Запрос на батчинг сегментацию""" | |
| image: str = Field(..., description="Изображение в base64 (с data URL или без)") | |
| prompts: List[PromptModel] = Field(..., description="Массив промптов") | |
| options: Optional[SegmentOptionsModel] = Field(default_factory=SegmentOptionsModel) | |
| class SegmentResultModel(BaseModel): | |
| """Результат сегментации одного объекта""" | |
| id: int | |
| label: Optional[str] = None | |
| bbox: Dict[str, Any] | |
| area: int | |
| center: Dict[str, int] | |
| confidence: float | |
| extracted_image: Optional[str] = None | |
| contours: Optional[List[Dict[str, Any]]] = None | |
| mask_rle: Optional[Dict[str, Any]] = None | |
| class BatchSegmentResponse(BaseModel): | |
| """Ответ батчинг сегментации""" | |
| success: bool | |
| image_size: Dict[str, int] | |
| results: List[SegmentResultModel] | |
| def save_batch_request_log(request_data: dict, response_data: dict, image_width: int, image_height: int): | |
| """ | |
| Сохраняет запрос батчинга для аудита и дебага. | |
| Создает папку с timestamp и сохраняет только метаданные: | |
| 1. Лог запроса (request.json) - параметры без base64 | |
| 2. Лог ответа (response.json) - результаты без base64 | |
| 3. Краткую сводку (summary.json) | |
| ⚠️ Изображения и маски НЕ сохраняются для безопасности! | |
| """ | |
| try: | |
| # Создаем корневую папку для логов | |
| logs_dir = "batch_logs" | |
| os.makedirs(logs_dir, exist_ok=True) | |
| # Создаем папку с timestamp | |
| timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Миллисекунды | |
| request_dir = os.path.join(logs_dir, timestamp) | |
| os.makedirs(request_dir, exist_ok=True) | |
| logger.info(f"📁 Сохраняю лог запроса в: {request_dir}") | |
| # Сохраняем запрос (без base64 для безопасности) | |
| request_log = { | |
| "timestamp": timestamp, | |
| "image_size": { | |
| "width": image_width, | |
| "height": image_height | |
| }, | |
| "prompts": [ | |
| { | |
| "id": p.get("id"), | |
| "type": p.get("type"), | |
| "label": p.get("label"), | |
| "bbox": p.get("bbox"), | |
| "selected": p.get("selected"), | |
| "data_length": len(p.get("data", "")) # Длина вместо самих данных | |
| } | |
| for p in request_data.get("prompts", []) | |
| ], | |
| "options": request_data.get("options", {}) | |
| } | |
| request_path = os.path.join(request_dir, "request.json") | |
| with open(request_path, "w", encoding="utf-8") as f: | |
| json.dump(request_log, f, indent=2, ensure_ascii=False) | |
| logger.info(f" ✓ Сохранен лог запроса: {request_path}") | |
| # 4. Сохраняем ответ (без base64 объектов) | |
| response_log = { | |
| "timestamp": timestamp, | |
| "success": response_data.get("success"), | |
| "image_size": response_data.get("image_size"), | |
| "results": [ | |
| { | |
| "id": r.get("id"), | |
| "label": r.get("label"), | |
| "bbox": r.get("bbox"), | |
| "area": r.get("area"), | |
| "center": r.get("center"), | |
| "confidence": r.get("confidence"), | |
| "has_extracted_image": "extracted_image" in r, | |
| "has_contours": "contours" in r | |
| } | |
| for r in response_data.get("results", []) | |
| ] | |
| } | |
| response_path = os.path.join(request_dir, "response.json") | |
| with open(response_path, "w", encoding="utf-8") as f: | |
| json.dump(response_log, f, indent=2, ensure_ascii=False) | |
| logger.info(f" ✓ Сохранен лог ответа: {response_path}") | |
| # 3. Создаем summary файл | |
| summary = { | |
| "timestamp": timestamp, | |
| "processed_prompts": len(response_data.get("results", [])), | |
| "total_prompts": len(request_data.get("prompts", [])), | |
| "selected_prompts": len([p for p in request_data.get("prompts", []) if p.get("selected", True)]), | |
| "image_size": f"{image_width}x{image_height}", | |
| "prompt_types": [p.get("type") for p in request_data.get("prompts", [])], | |
| "files": { | |
| "request": "request.json", | |
| "response": "response.json" | |
| } | |
| } | |
| summary_path = os.path.join(request_dir, "summary.json") | |
| with open(summary_path, "w", encoding="utf-8") as f: | |
| json.dump(summary, f, indent=2, ensure_ascii=False) | |
| logger.info(f"✅ Лог запроса сохранен: {request_dir}") | |
| except Exception as e: | |
| logger.error(f"❌ Ошибка при сохранении лога: {e}") | |
| # Не прерываем обработку запроса если не удалось сохранить лог | |
| def load_model(checkpoint_path: str = "checkpoints/sam2.1_hiera_tiny.pt"): | |
| """ | |
| Загружает модель SAM2. | |
| Вызывается один раз при старте сервера. | |
| """ | |
| global predictor, device | |
| try: | |
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| # Проверяем CUDA | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| logger.info(f"Используем устройство: {device}") | |
| if device == "cpu": | |
| logger.warning("CUDA недоступна, работаем на CPU (будет медленно как черепаха)") | |
| # Определяем конфиг по имени файла чекпоинта | |
| # Указываем путь относительно configs/ директории в пакете sam2 | |
| checkpoint_name = os.path.basename(checkpoint_path) | |
| if "tiny" in checkpoint_name: | |
| config = "configs/sam2.1/sam2.1_hiera_t.yaml" | |
| elif "small" in checkpoint_name: | |
| config = "configs/sam2.1/sam2.1_hiera_s.yaml" | |
| elif "base_plus" in checkpoint_name: | |
| config = "configs/sam2.1/sam2.1_hiera_b+.yaml" | |
| elif "large" in checkpoint_name: | |
| config = "configs/sam2.1/sam2.1_hiera_l.yaml" | |
| else: | |
| logger.warning(f"Неизвестный тип модели, пробую tiny конфиг") | |
| config = "configs/sam2.1/sam2.1_hiera_t.yaml" | |
| logger.info(f"Загружаю модель из {checkpoint_path}") | |
| logger.info(f"Конфиг: {config}") | |
| sam2_model = build_sam2(config, checkpoint_path, device=device) | |
| predictor = SAM2ImagePredictor(sam2_model) | |
| logger.info("✓ Модель загружена успешно") | |
| except Exception as e: | |
| logger.error(f"Не удалось загрузить модель: {e}") | |
| logger.error("Убедись что SAM2 установлен (./install_sam2.sh)") | |
| raise | |
| async def lifespan(app: FastAPI): | |
| """Загружаем модель при старте, выгружаем при остановке""" | |
| # Startup | |
| checkpoint_dir = "checkpoints" | |
| if os.path.exists(checkpoint_dir): | |
| checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")] | |
| if checkpoints: | |
| checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0]) | |
| load_model(checkpoint_path) | |
| else: | |
| logger.error("Нет чекпоинтов в директории checkpoints/") | |
| logger.error("Запусти: python download_model.py") | |
| else: | |
| logger.error("Директория checkpoints/ не найдена") | |
| yield # Сервер работает | |
| # Shutdown (если нужна очистка) | |
| # Создаем FastAPI приложение с lifespan | |
| app = FastAPI( | |
| title="SAM2 Segmentation API", | |
| description="API для автоматической сегментации объектов на изображениях", | |
| version="1.0.0", | |
| lifespan=lifespan | |
| ) | |
| # Добавляем CORS для работы с веб-интерфейсом | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # В продакшене указать конкретные домены | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| async def root(): | |
| """Главная страница - информация об API""" | |
| return { | |
| "message": "SAM2 Segmentation API работает", | |
| "version": "2.0.0", | |
| "web_ui": { | |
| "simple": "/web - Box промпты", | |
| "advanced": "/web/advanced - Box + Brush промпты (рисование)" | |
| }, | |
| "docs": "/docs", | |
| "endpoints": { | |
| "POST /segment": "Сегментация изображения (поддерживает points, box, mask via query params)", | |
| "POST /segment/batch": "🔥 Батчинг сегментация (JSON API для множественных объектов)", | |
| "POST /segment/auto": "Автоматическая сегментация всех объектов", | |
| "GET /health": "Проверка здоровья сервиса" | |
| } | |
| } | |
| async def web_interface(): | |
| """Веб-интерфейс для тестирования Box Prompts (простой)""" | |
| web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo.html") | |
| if os.path.exists(web_demo_path): | |
| with open(web_demo_path, "r", encoding="utf-8") as f: | |
| return f.read() | |
| else: | |
| return "<h1>Веб-интерфейс не найден</h1><p>Файл web_demo.html отсутствует</p>" | |
| async def web_interface_advanced(): | |
| """Продвинутый веб-интерфейс с Box + Brush промптами""" | |
| web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo_advanced.html") | |
| if os.path.exists(web_demo_path): | |
| with open(web_demo_path, "r", encoding="utf-8") as f: | |
| return f.read() | |
| else: | |
| return "<h1>Продвинутый интерфейс не найден</h1><p>Файл web_demo_advanced.html отсутствует</p>" | |
| async def health(): | |
| """Проверка что всё ок""" | |
| return { | |
| "status": "healthy" if predictor is not None else "model not loaded", | |
| "device": str(device) if device else "unknown" | |
| } | |
| def process_image(image_bytes: bytes) -> np.ndarray: | |
| """Конвертирует байты в numpy array""" | |
| image = Image.open(io.BytesIO(image_bytes)) | |
| if image.mode != "RGB": | |
| image = image.convert("RGB") | |
| return np.array(image) | |
| def masks_to_coords(masks: np.ndarray, include_contours: bool = False) -> List[Dict[str, Any]]: | |
| """ | |
| Конвертирует маски в координаты bounding box и контуров. | |
| masks: (N, H, W) - N масок | |
| include_contours: если True, добавляет контуры масок | |
| """ | |
| results = [] | |
| for i, mask in enumerate(masks): | |
| # Находим координаты пикселей маски | |
| y_coords, x_coords = np.where(mask > 0) | |
| if len(x_coords) == 0: | |
| continue | |
| # Bounding box | |
| x_min, x_max = int(x_coords.min()), int(x_coords.max()) | |
| y_min, y_max = int(y_coords.min()), int(y_coords.max()) | |
| # Площадь сегмента | |
| area = int(mask.sum()) | |
| segment_data = { | |
| "segment_id": i, | |
| "bbox": { | |
| "x_min": x_min, | |
| "y_min": y_min, | |
| "x_max": x_max, | |
| "y_max": y_max, | |
| "width": x_max - x_min, | |
| "height": y_max - y_min | |
| }, | |
| "area": area, | |
| "center": { | |
| "x": int(x_coords.mean()), | |
| "y": int(y_coords.mean()) | |
| } | |
| } | |
| # Добавляем контуры если нужно | |
| if include_contours: | |
| try: | |
| # Конвертируем маску в uint8 (защита от булевых масок) | |
| if mask.dtype == bool: | |
| mask_uint8 = mask.astype(np.uint8) * 255 | |
| else: | |
| mask_uint8 = (mask * 255).astype(np.uint8) | |
| # Находим контуры с иерархией для поддержки "дыр" | |
| # RETR_CCOMP: находит внешние контуры И внутренние дыры (holes) | |
| # CHAIN_APPROX_NONE: сохраняет ВСЕ точки для pixel-perfect результата | |
| contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) | |
| except Exception as e: | |
| logger.warning(f"Ошибка при извлечении контуров: {e}, использую fallback") | |
| # Fallback на простое извлечение без иерархии | |
| contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) | |
| hierarchy = None | |
| # Конвертируем контуры в список точек с учетом иерархии | |
| contour_data = [] | |
| if hierarchy is not None and len(contours) > 0: | |
| hierarchy = hierarchy[0] # OpenCV возвращает hierarchy в странном формате | |
| for i, contour in enumerate(contours): | |
| try: | |
| # Небольшое упрощение только для очень больших контуров | |
| if len(contour) > 1000: | |
| arc_length = cv2.arcLength(contour, True) | |
| if arc_length > 0: # Защита от деления на 0 | |
| epsilon = 0.0005 * arc_length | |
| approx = cv2.approxPolyDP(contour, epsilon, True) | |
| else: | |
| approx = contour | |
| else: | |
| approx = contour | |
| # Конвертируем в список [x, y] | |
| points = [[int(point[0][0]), int(point[0][1])] for point in approx] | |
| if len(points) > 2: | |
| # hierarchy[i] = [Next, Previous, First_Child, Parent] | |
| # Если Parent == -1, это внешний контур | |
| # Если Parent >= 0, это дыра (hole) внутри родительского контура | |
| is_hole = hierarchy[i][3] != -1 | |
| contour_data.append({ | |
| "points": points, | |
| "is_hole": is_hole | |
| }) | |
| except Exception as e: | |
| logger.warning(f"Ошибка при обработке контура {i}: {e}") | |
| continue | |
| else: | |
| # Fallback если hierarchy не вернулась | |
| for contour in contours: | |
| try: | |
| if len(contour) > 1000: | |
| arc_length = cv2.arcLength(contour, True) | |
| if arc_length > 0: | |
| epsilon = 0.0005 * arc_length | |
| approx = cv2.approxPolyDP(contour, epsilon, True) | |
| else: | |
| approx = contour | |
| else: | |
| approx = contour | |
| points = [[int(point[0][0]), int(point[0][1])] for point in approx] | |
| if len(points) > 2: | |
| contour_data.append({ | |
| "points": points, | |
| "is_hole": False | |
| }) | |
| except Exception as e: | |
| logger.warning(f"Ошибка при обработке контура: {e}") | |
| continue | |
| segment_data["contours"] = contour_data if len(contour_data) > 0 else [] | |
| # Также добавляем RLE (Run-Length Encoding) для компактного представления | |
| # Это полезно если нужно восстановить точную маску | |
| segment_data["mask_rle"] = mask_to_rle(mask) | |
| results.append(segment_data) | |
| return results | |
| def mask_to_rle(mask: np.ndarray) -> Dict[str, Any]: | |
| """ | |
| Конвертирует бинарную маску в RLE (Run-Length Encoding) | |
| Компактное представление маски | |
| """ | |
| # Конвертируем в int если это bool | |
| if mask.dtype == bool: | |
| pixels = mask.astype(np.uint8).flatten() | |
| else: | |
| pixels = mask.flatten() | |
| pixels = np.concatenate([[0], pixels, [0]]) | |
| runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 | |
| runs[1::2] -= runs[::2] | |
| return { | |
| "counts": [int(x) for x in runs], # Конвертируем numpy int в Python int | |
| "size": [int(x) for x in mask.shape] # Конвертируем в Python int | |
| } | |
| def convert_to_native_types(obj): | |
| """ | |
| Рекурсивно конвертирует numpy типы в нативные Python типы | |
| Нужно для сериализации в JSON через FastAPI | |
| """ | |
| if isinstance(obj, np.integer): | |
| return int(obj) | |
| elif isinstance(obj, np.floating): | |
| return float(obj) | |
| elif isinstance(obj, np.ndarray): | |
| return obj.tolist() | |
| elif isinstance(obj, np.bool_): | |
| return bool(obj) | |
| elif isinstance(obj, dict): | |
| return {key: convert_to_native_types(value) for key, value in obj.items()} | |
| elif isinstance(obj, list): | |
| return [convert_to_native_types(item) for item in obj] | |
| return obj | |
| def clean_mask(mask: np.ndarray, min_area: int = 100) -> np.ndarray: | |
| """ | |
| Очищает маску от мелких артефактов и дыр. | |
| Более мягкий вариант - не убивает тонкие детали типа лямок. | |
| mask: бинарная маска (H, W) | |
| min_area: минимальная площадь компонента в пикселях | |
| Returns: очищенная маска | |
| """ | |
| # Конвертируем в uint8 если нужно | |
| if mask.dtype == bool: | |
| mask_uint8 = mask.astype(np.uint8) * 255 | |
| else: | |
| mask_uint8 = (mask * 255).astype(np.uint8) | |
| # Только легкое закрытие для удаления мелких дыр внутри объекта | |
| # Используем маленький kernel чтобы не убить тонкие детали (лямки, пальцы и т.д.) | |
| kernel = np.ones((2, 2), np.uint8) | |
| mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=1) | |
| # УБРАЛ MORPH_OPEN - он убивал тонкие элементы типа лямок портфеля | |
| # УБРАЛ фильтрацию по площади компонентов - она тоже могла вырезать лямки | |
| return (mask_uint8 > 127).astype(bool) | |
| def extract_object_image(image: np.ndarray, mask: np.ndarray, clean: bool = True) -> str: | |
| """ | |
| Вырезает объект из изображения по маске и возвращает base64 PNG с прозрачностью. | |
| image: RGB изображение (H, W, 3) | |
| mask: бинарная маска (H, W) | |
| clean: применить постобработку для удаления артефактов | |
| Returns: base64 строка PNG изображения с альфа-каналом | |
| """ | |
| # Конвертируем маску в bool если нужно | |
| if mask.dtype != bool: | |
| mask = mask > 0.5 | |
| # Очищаем маску от артефактов | |
| if clean: | |
| mask = clean_mask(mask, min_area=100) | |
| # Создаем RGBA изображение | |
| h, w = image.shape[:2] | |
| rgba = np.zeros((h, w, 4), dtype=np.uint8) | |
| rgba[:, :, :3] = image # RGB каналы | |
| rgba[:, :, 3] = (mask * 255).astype(np.uint8) # Alpha канал из маски | |
| # Конвертируем в PIL Image | |
| pil_image = Image.fromarray(rgba, 'RGBA') | |
| # Конвертируем в base64 | |
| buffer = io.BytesIO() | |
| pil_image.save(buffer, format='PNG') | |
| buffer.seek(0) | |
| img_base64 = base64.b64encode(buffer.read()).decode('utf-8') | |
| return f"data:image/png;base64,{img_base64}" | |
| async def segment_image( | |
| file: UploadFile = File(...), | |
| point_x: List[float] = Query(None, description="X координаты точек промпта"), | |
| point_y: List[float] = Query(None, description="Y координаты точек промпта"), | |
| point_labels: List[int] = Query(None, description="Лейблы точек (1=foreground, 0=background)"), | |
| box_x1: float = Query(None, description="X координата левого верхнего угла бокса"), | |
| box_y1: float = Query(None, description="Y координата левого верхнего угла бокса"), | |
| box_x2: float = Query(None, description="X координата правого нижнего угла бокса"), | |
| box_y2: float = Query(None, description="Y координата правого нижнего угла бокса"), | |
| mask_data: str = Query(None, description="Base64 закодированная маска (PNG с альфа-каналом)"), | |
| include_masks: bool = Query(True, description="Включить контуры масок в ответ"), | |
| extract_objects: bool = Query(False, description="Вернуть вырезанные объекты как base64 PNG"), | |
| ): | |
| """ | |
| Сегментирует изображение по промпту (точкам, боксу, маске или их комбинации). | |
| Поддерживаемые промпты: | |
| - Точки (point_x, point_y, point_labels) - клики пользователя | |
| - Бокс (box_x1, box_y1, box_x2, box_y2) - прямоугольное выделение | |
| - Маска (mask_data) - нарисованная кистью маска (зеленый=foreground, красный=background) | |
| - Комбинация промптов - для максимальной точности | |
| Если промпты не указаны, сегментирует центральный объект. | |
| Если include_masks=True, возвращает контуры масок для точной отрисовки. | |
| Если extract_objects=True, возвращает готовые вырезанные объекты как base64 PNG. | |
| """ | |
| if predictor is None: | |
| raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер") | |
| try: | |
| # Читаем изображение | |
| image_bytes = await file.read() | |
| image = process_image(image_bytes) | |
| logger.info(f"Обрабатываю изображение: {image.shape}") | |
| logger.info(f"Параметры: include_masks={include_masks}, extract_objects={extract_objects}") | |
| # Устанавливаем изображение в предиктор | |
| predictor.set_image(image) | |
| # Подготавливаем промпты | |
| points = None | |
| labels = None | |
| box = None | |
| # Проверяем наличие точек | |
| if point_x and point_y: | |
| if len(point_x) != len(point_y): | |
| raise HTTPException(status_code=400, detail="Количество X и Y координат должно совпадать") | |
| points = np.array([[x, y] for x, y in zip(point_x, point_y)]) | |
| labels = np.array(point_labels) if point_labels else np.ones(len(points)) | |
| logger.info(f"Промпт: {len(points)} точек") | |
| # Проверяем наличие бокса | |
| if all(v is not None for v in [box_x1, box_y1, box_x2, box_y2]): | |
| box = np.array([box_x1, box_y1, box_x2, box_y2]) | |
| logger.info(f"Промпт: бокс [{box_x1:.1f}, {box_y1:.1f}, {box_x2:.1f}, {box_y2:.1f}]") | |
| # Валидация бокса | |
| if box_x2 <= box_x1 or box_y2 <= box_y1: | |
| raise HTTPException( | |
| status_code=400, | |
| detail="Некорректный бокс: x2 должен быть больше x1, y2 больше y1" | |
| ) | |
| # Проверяем наличие нарисованной маски | |
| if mask_data: | |
| logger.info("Обрабатываю нарисованную маску...") | |
| try: | |
| # Декодируем base64 | |
| if ',' in mask_data: | |
| mask_data = mask_data.split(',')[1] # Убираем data:image/png;base64, | |
| mask_bytes = base64.b64decode(mask_data) | |
| mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA') | |
| mask_array = np.array(mask_image) | |
| # Извлекаем foreground и background пиксели | |
| # Поддерживаем несколько форматов: | |
| # 1. Зеленый (R<100, G>150, B<100) - классический foreground | |
| # 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами | |
| # 3. Красный (R>150, G<100, B<100) - background | |
| green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) | |
| white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0) | |
| red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) | |
| # Объединяем зеленые и белые как foreground | |
| foreground_mask = green_mask | white_mask | |
| # Сэмплируем точки из закрашенных областей | |
| mask_points = [] | |
| mask_labels = [] | |
| # Foreground точки (зеленые + белые) | |
| foreground_coords = np.argwhere(foreground_mask) | |
| if len(foreground_coords) > 0: | |
| # Масштабируем к размеру исходного изображения | |
| scale_y = image.shape[0] / mask_array.shape[0] | |
| scale_x = image.shape[1] / mask_array.shape[1] | |
| # Сэмплируем до 20 точек равномерно (меньше = стабильнее) | |
| step = max(1, len(foreground_coords) // 20) | |
| sampled = foreground_coords[::step][:20] # Максимум 20 точек | |
| for y, x in sampled: | |
| mask_points.append([x * scale_x, y * scale_y]) | |
| mask_labels.append(1) # foreground | |
| # Background точки (красные) | |
| red_coords = np.argwhere(red_mask) | |
| if len(red_coords) > 0: | |
| scale_y = image.shape[0] / mask_array.shape[0] | |
| scale_x = image.shape[1] / mask_array.shape[1] | |
| step = max(1, len(red_coords) // 20) | |
| sampled = red_coords[::step][:20] # Максимум 20 точек | |
| for y, x in sampled: | |
| mask_points.append([x * scale_x, y * scale_y]) | |
| mask_labels.append(0) # background | |
| if mask_points: | |
| # Объединяем с существующими точками | |
| if points is not None: | |
| points = np.vstack([points, np.array(mask_points)]) | |
| labels = np.concatenate([labels, np.array(mask_labels)]) | |
| else: | |
| points = np.array(mask_points) | |
| labels = np.array(mask_labels) | |
| logger.info(f"Промпт из маски: {len(mask_points)} точек ({np.sum(np.array(mask_labels) == 1)} foreground, {np.sum(np.array(mask_labels) == 0)} background)") | |
| else: | |
| logger.warning("Маска пустая или не содержит foreground (зеленых/белых) или background (красных) пикселей") | |
| except Exception as e: | |
| logger.error(f"Ошибка обработки маски: {e}") | |
| raise HTTPException(status_code=400, detail=f"Некорректная маска: {str(e)}") | |
| # Делаем предсказание с промптами | |
| if points is not None or box is not None: | |
| logger.info(f"Используем промпты: points={points is not None}, box={box is not None}") | |
| # Если много точек (>10), используем single mask для стабильности | |
| # Если мало точек или только box, используем multimask для вариативности | |
| use_multimask = True | |
| if points is not None and len(points) > 10: | |
| use_multimask = False | |
| logger.info("Много точек, используем single mask mode для стабильности") | |
| masks, scores, logits = predictor.predict( | |
| point_coords=points, | |
| point_labels=labels, | |
| box=box, | |
| multimask_output=use_multimask, | |
| ) | |
| # Если multimask, выбираем лучшую по score | |
| if use_multimask and len(masks) > 1: | |
| best_idx = np.argmax(scores) | |
| masks = masks[best_idx:best_idx+1] | |
| scores = scores[best_idx:best_idx+1] | |
| logger.info(f"Выбрана маска {best_idx} с confidence {scores[0]:.3f}") | |
| else: | |
| # Автоматическая сегментация - берем центральную точку | |
| logger.info("Промпты не указаны, сегментирую центральный объект") | |
| h, w = image.shape[:2] | |
| point = np.array([[w // 2, h // 2]]) | |
| label = np.array([1]) | |
| masks, scores, logits = predictor.predict( | |
| point_coords=point, | |
| point_labels=label, | |
| multimask_output=True, | |
| ) | |
| # Конвертируем маски в координаты (с контурами если нужно) | |
| segments = masks_to_coords(masks, include_contours=include_masks) | |
| logger.info(f"Найдено сегментов: {len(segments)}, масок: {len(masks)}") | |
| logger.info(f"extract_objects = {extract_objects}") | |
| # Добавляем confidence scores | |
| for i, seg in enumerate(segments): | |
| seg["confidence"] = float(scores[i]) if i < len(scores) else 0.0 | |
| # Если нужно - вырезаем объект и добавляем base64 | |
| logger.info(f"Обрабатываю сегмент {i}: extract_objects={extract_objects}, i < len(masks) = {i < len(masks)}") | |
| if extract_objects and i < len(masks): | |
| logger.info(f"Вырезаю объект {i}...") | |
| seg["extracted_image"] = extract_object_image(image, masks[i]) | |
| logger.info(f"✓ Вырезан объект {i}, размер маски: {masks[i].sum()} пикселей") | |
| else: | |
| logger.warning(f"❌ Пропускаю объект {i}: extract_objects={extract_objects}") | |
| result = { | |
| "success": True, | |
| "image_size": { | |
| "width": int(image.shape[1]), | |
| "height": int(image.shape[0]) | |
| }, | |
| "segments_count": len(segments), | |
| "segments": segments | |
| } | |
| # Конвертируем все numpy типы в нативные Python типы | |
| return convert_to_native_types(result) | |
| except Exception as e: | |
| logger.error(f"Ошибка при сегментации: {e}") | |
| raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") | |
| async def segment_auto( | |
| file: UploadFile = File(...), | |
| points_per_side: int = Query(32, description="Количество точек на сторону для автосегментации"), | |
| include_masks: bool = Query(True, description="Включить контуры масок в ответ"), | |
| ): | |
| """ | |
| Автоматическая сегментация всех объектов на изображении. | |
| Использует grid of points для поиска всех возможных объектов. | |
| Если include_masks=True, возвращает контуры масок для точной отрисовки. | |
| """ | |
| if predictor is None: | |
| raise HTTPException(status_code=503, detail="Модель не загружена") | |
| try: | |
| image_bytes = await file.read() | |
| image = process_image(image_bytes) | |
| logger.info(f"Автосегментация изображения: {image.shape}") | |
| predictor.set_image(image) | |
| # Создаем сетку точек | |
| h, w = image.shape[:2] | |
| x_coords = np.linspace(0, w, points_per_side) | |
| y_coords = np.linspace(0, h, points_per_side) | |
| all_segments = [] | |
| segment_id = 0 | |
| # Для каждой точки в сетке пытаемся найти объект | |
| for y in y_coords: | |
| for x in x_coords: | |
| point = np.array([[x, y]]) | |
| label = np.array([1]) | |
| masks, scores, _ = predictor.predict( | |
| point_coords=point, | |
| point_labels=label, | |
| multimask_output=False, | |
| ) | |
| if masks.shape[0] > 0 and scores[0] > 0.5: # Порог confidence | |
| segments = masks_to_coords(masks, include_contours=include_masks) | |
| for seg in segments: | |
| seg["segment_id"] = segment_id | |
| seg["confidence"] = float(scores[0]) | |
| all_segments.append(seg) | |
| segment_id += 1 | |
| # Убираем дубликаты (примерно) | |
| # Два сегмента считаем дубликатами если их центры близко | |
| unique_segments = [] | |
| for seg in all_segments: | |
| is_duplicate = False | |
| for unique_seg in unique_segments: | |
| dx = seg["center"]["x"] - unique_seg["center"]["x"] | |
| dy = seg["center"]["y"] - unique_seg["center"]["y"] | |
| dist = (dx**2 + dy**2) ** 0.5 | |
| if dist < 50: # Порог расстояния между центрами | |
| is_duplicate = True | |
| break | |
| if not is_duplicate: | |
| unique_segments.append(seg) | |
| result = { | |
| "success": True, | |
| "image_size": { | |
| "width": int(image.shape[1]), | |
| "height": int(image.shape[0]) | |
| }, | |
| "segments_count": len(unique_segments), | |
| "segments": unique_segments | |
| } | |
| # Конвертируем все numpy типы в нативные Python типы | |
| return convert_to_native_types(result) | |
| except Exception as e: | |
| logger.error(f"Ошибка при автосегментации: {e}") | |
| raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") | |
| async def segment_batch(request: BatchSegmentRequest = Body(...)): | |
| """ | |
| Батчинг сегментация нескольких объектов. | |
| Принимает изображение и массив промптов (mask/box/points). | |
| Обрабатывает каждый selected промпт отдельно. | |
| Возвращает массив результатов с метаданными. | |
| Идеально для: | |
| - Множественных объектов | |
| - Мобильных приложений | |
| - Когда фронт уже разделил объекты | |
| """ | |
| if predictor is None: | |
| raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер") | |
| try: | |
| # Декодируем изображение из base64 | |
| image_data = request.image | |
| if ',' in image_data: | |
| image_data = image_data.split(',')[1] # Убираем data:image/...;base64, | |
| image_bytes = base64.b64decode(image_data) | |
| image = process_image(image_bytes) | |
| logger.info(f"Батчинг сегментация: {image.shape}, промптов: {len(request.prompts)}") | |
| # Устанавливаем изображение один раз | |
| predictor.set_image(image) | |
| results = [] | |
| # Фильтруем только selected промпты | |
| selected_prompts = [p for p in request.prompts if p.selected] | |
| logger.info(f"Обрабатываем {len(selected_prompts)} из {len(request.prompts)} промптов") | |
| # Обрабатываем каждый промпт отдельно | |
| for prompt in selected_prompts: | |
| logger.info(f"Обрабатываю промпт #{prompt.id}, тип: {prompt.type}, label: {prompt.label}") | |
| try: | |
| # Подготавливаем промпт в зависимости от типа | |
| points = None | |
| labels = None | |
| box = None | |
| if prompt.type == "mask": | |
| # Декодируем маску и извлекаем точки | |
| mask_data = prompt.data | |
| if ',' in mask_data: | |
| mask_data = mask_data.split(',')[1] | |
| mask_bytes = base64.b64decode(mask_data) | |
| mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA') | |
| mask_array = np.array(mask_image) | |
| # Извлекаем foreground и background пиксели | |
| # Поддерживаем несколько форматов: | |
| # 1. Зеленый (R<100, G>150, B<100) - классический foreground | |
| # 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами | |
| # 3. Красный (R>150, G<100, B<100) - background | |
| green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) | |
| white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0) | |
| red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) | |
| # Объединяем зеленые и белые как foreground | |
| foreground_mask = green_mask | white_mask | |
| mask_points = [] | |
| mask_labels = [] | |
| # Foreground точки (зеленые + белые) | |
| foreground_coords = np.argwhere(foreground_mask) | |
| if len(foreground_coords) > 0: | |
| scale_y = image.shape[0] / mask_array.shape[0] | |
| scale_x = image.shape[1] / mask_array.shape[1] | |
| step = max(1, len(foreground_coords) // 20) | |
| sampled = foreground_coords[::step][:20] | |
| for y, x in sampled: | |
| mask_points.append([x * scale_x, y * scale_y]) | |
| mask_labels.append(1) | |
| # Background точки | |
| red_coords = np.argwhere(red_mask) | |
| if len(red_coords) > 0: | |
| scale_y = image.shape[0] / mask_array.shape[0] | |
| scale_x = image.shape[1] / mask_array.shape[1] | |
| step = max(1, len(red_coords) // 20) | |
| sampled = red_coords[::step][:20] | |
| for y, x in sampled: | |
| mask_points.append([x * scale_x, y * scale_y]) | |
| mask_labels.append(0) | |
| if mask_points: | |
| points = np.array(mask_points) | |
| labels = np.array(mask_labels) | |
| elif prompt.type == "box": | |
| # Парсим bbox - может быть нормализованный (0-1) или пиксельный | |
| bbox_data = prompt.bbox if prompt.bbox else None | |
| if bbox_data: | |
| x1 = bbox_data.x_min | |
| y1 = bbox_data.y_min | |
| x2 = bbox_data.x_max | |
| y2 = bbox_data.y_max | |
| # Если нормализованные координаты (0-1), конвертируем в пиксели | |
| if x2 <= 1.0 and y2 <= 1.0: | |
| x1 *= image.shape[1] | |
| x2 *= image.shape[1] | |
| y1 *= image.shape[0] | |
| y2 *= image.shape[0] | |
| box = np.array([x1, y1, x2, y2]) | |
| elif prompt.type == "points": | |
| # Ожидаем JSON в формате [[x, y, label], ...] | |
| import json | |
| points_data = json.loads(prompt.data) | |
| points_list = [] | |
| labels_list = [] | |
| for point in points_data: | |
| x, y = point[0], point[1] | |
| label = point[2] if len(point) > 2 else 1 | |
| # Если нормализованные, конвертируем | |
| if x <= 1.0 and y <= 1.0: | |
| x *= image.shape[1] | |
| y *= image.shape[0] | |
| points_list.append([x, y]) | |
| labels_list.append(label) | |
| points = np.array(points_list) | |
| labels = np.array(labels_list) | |
| # Делаем предсказание | |
| if points is not None or box is not None: | |
| # Решаем использовать ли multimask | |
| use_multimask = True | |
| if points is not None and len(points) > 10: | |
| use_multimask = False | |
| masks, scores, logits = predictor.predict( | |
| point_coords=points, | |
| point_labels=labels, | |
| box=box, | |
| multimask_output=use_multimask, | |
| ) | |
| # Если multimask, выбираем лучшую | |
| if use_multimask and len(masks) > 1: | |
| best_idx = np.argmax(scores) | |
| masks = masks[best_idx:best_idx+1] | |
| scores = scores[best_idx:best_idx+1] | |
| # Берем первую маску | |
| mask = masks[0] | |
| score = float(scores[0]) | |
| # Очищаем маску если нужно | |
| if request.options.clean_masks: | |
| mask = clean_mask(mask, min_area=100) | |
| # Вычисляем метрики | |
| y_coords, x_coords = np.where(mask > 0) | |
| if len(x_coords) > 0: | |
| x_min, x_max = int(x_coords.min()), int(x_coords.max()) | |
| y_min, y_max = int(y_coords.min()), int(y_coords.max()) | |
| area = int(mask.sum()) | |
| center_x = int(x_coords.mean()) | |
| center_y = int(y_coords.mean()) | |
| # Формируем результат | |
| result = { | |
| "id": prompt.id, | |
| "label": prompt.label, | |
| "bbox": { | |
| "x_min": x_min, | |
| "y_min": y_min, | |
| "x_max": x_max, | |
| "y_max": y_max, | |
| "width": x_max - x_min, | |
| "height": y_max - y_min | |
| }, | |
| "area": area, | |
| "center": { | |
| "x": center_x, | |
| "y": center_y | |
| }, | |
| "confidence": score | |
| } | |
| # Добавляем вырезанный объект если нужно | |
| if request.options.extract_objects: | |
| result["extracted_image"] = extract_object_image( | |
| image, mask, clean=request.options.clean_masks | |
| ) | |
| # Добавляем контуры если нужно | |
| if request.options.include_masks: | |
| segments = masks_to_coords(masks, include_contours=True) | |
| if segments: | |
| result["contours"] = segments[0].get("contours", []) | |
| result["mask_rle"] = segments[0].get("mask_rle", {}) | |
| results.append(result) | |
| logger.info(f"✓ Промпт #{prompt.id} обработан, confidence: {score:.3f}") | |
| else: | |
| logger.warning(f"✗ Промпт #{prompt.id} не дал результата") | |
| else: | |
| logger.warning(f"✗ Промпт #{prompt.id}: нет данных для сегментации") | |
| except Exception as e: | |
| logger.error(f"✗ Ошибка обработки промпта #{prompt.id}: {e}") | |
| # Продолжаем обработку остальных промптов | |
| continue | |
| response = { | |
| "success": True, | |
| "image_size": { | |
| "width": int(image.shape[1]), | |
| "height": int(image.shape[0]) | |
| }, | |
| "results": results | |
| } | |
| logger.info(f"Батчинг завершен: обработано {len(results)} объектов") | |
| # Сохраняем лог запроса для аудита (только метаданные, без изображений) | |
| try: | |
| request_dict = request.dict() | |
| save_batch_request_log(request_dict, response, image.shape[1], image.shape[0]) | |
| except Exception as e: | |
| logger.warning(f"Не удалось сохранить лог запроса: {e}") | |
| return convert_to_native_types(response) | |
| except Exception as e: | |
| logger.error(f"Ошибка при батчинг сегментации: {e}") | |
| raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| import os | |
| # Порт из переменной окружения (для HF Spaces) или 8000 по умолчанию | |
| port = int(os.getenv("PORT", 8000)) | |
| uvicorn.run(app, host="0.0.0.0", port=port) | |