""" REST API сервер для сегментации изображений через SAM2. Уставший сеньор кодит это в 3 часа ночи, поэтому код местами будет грязный. """ from contextlib import asynccontextmanager from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Body from fastapi.responses import JSONResponse, HTMLResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from PIL import Image import numpy as np import torch import io import os import base64 import cv2 from typing import List, Dict, Any, Optional, Literal import logging from datetime import datetime import json # Настройка логирования, потому что дебажить это говно иначе невозможно logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Глобальные переменные для модели (лень каждый раз загружать) predictor = None device = None # ===== Pydantic модели для батчинг API ===== class BBoxModel(BaseModel): """Bounding box в нормализованных координатах (0.0 - 1.0) или пиксельных""" x_min: float = Field(..., description="X координата левого верхнего угла") y_min: float = Field(..., description="Y координата левого верхнего угла") x_max: float = Field(..., description="X координата правого нижнего угла") y_max: float = Field(..., description="Y координата правого нижнего угла") class PromptModel(BaseModel): """Промпт для сегментации одного объекта""" id: int = Field(..., description="Уникальный ID объекта") type: Literal["mask", "box", "points"] = Field(..., description="Тип промпта") data: str = Field(..., description="Данные промпта (base64 для mask, JSON для points)") bbox: Optional[BBoxModel] = Field(None, description="Опциональный bounding box") label: Optional[str] = Field(None, description="Метка объекта (person, car, etc)") selected: bool = Field(True, description="Обрабатывать ли этот промпт") class SegmentOptionsModel(BaseModel): """Опции сегментации""" extract_objects: bool = Field(True, description="Вернуть вырезанные объекты") include_masks: bool = Field(False, description="Включить контуры масок") clean_masks: bool = Field(True, description="Очистить маски от артефактов") class BatchSegmentRequest(BaseModel): """Запрос на батчинг сегментацию""" image: str = Field(..., description="Изображение в base64 (с data URL или без)") prompts: List[PromptModel] = Field(..., description="Массив промптов") options: Optional[SegmentOptionsModel] = Field(default_factory=SegmentOptionsModel) class SegmentResultModel(BaseModel): """Результат сегментации одного объекта""" id: int label: Optional[str] = None bbox: Dict[str, Any] area: int center: Dict[str, int] confidence: float extracted_image: Optional[str] = None contours: Optional[List[Dict[str, Any]]] = None mask_rle: Optional[Dict[str, Any]] = None class BatchSegmentResponse(BaseModel): """Ответ батчинг сегментации""" success: bool image_size: Dict[str, int] results: List[SegmentResultModel] def save_batch_request_log(request_data: dict, response_data: dict, image_width: int, image_height: int): """ Сохраняет запрос батчинга для аудита и дебага. Создает папку с timestamp и сохраняет только метаданные: 1. Лог запроса (request.json) - параметры без base64 2. Лог ответа (response.json) - результаты без base64 3. Краткую сводку (summary.json) ⚠️ Изображения и маски НЕ сохраняются для безопасности! """ try: # Создаем корневую папку для логов logs_dir = "batch_logs" os.makedirs(logs_dir, exist_ok=True) # Создаем папку с timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3] # Миллисекунды request_dir = os.path.join(logs_dir, timestamp) os.makedirs(request_dir, exist_ok=True) logger.info(f"📁 Сохраняю лог запроса в: {request_dir}") # Сохраняем запрос (без base64 для безопасности) request_log = { "timestamp": timestamp, "image_size": { "width": image_width, "height": image_height }, "prompts": [ { "id": p.get("id"), "type": p.get("type"), "label": p.get("label"), "bbox": p.get("bbox"), "selected": p.get("selected"), "data_length": len(p.get("data", "")) # Длина вместо самих данных } for p in request_data.get("prompts", []) ], "options": request_data.get("options", {}) } request_path = os.path.join(request_dir, "request.json") with open(request_path, "w", encoding="utf-8") as f: json.dump(request_log, f, indent=2, ensure_ascii=False) logger.info(f" ✓ Сохранен лог запроса: {request_path}") # 4. Сохраняем ответ (без base64 объектов) response_log = { "timestamp": timestamp, "success": response_data.get("success"), "image_size": response_data.get("image_size"), "results": [ { "id": r.get("id"), "label": r.get("label"), "bbox": r.get("bbox"), "area": r.get("area"), "center": r.get("center"), "confidence": r.get("confidence"), "has_extracted_image": "extracted_image" in r, "has_contours": "contours" in r } for r in response_data.get("results", []) ] } response_path = os.path.join(request_dir, "response.json") with open(response_path, "w", encoding="utf-8") as f: json.dump(response_log, f, indent=2, ensure_ascii=False) logger.info(f" ✓ Сохранен лог ответа: {response_path}") # 3. Создаем summary файл summary = { "timestamp": timestamp, "processed_prompts": len(response_data.get("results", [])), "total_prompts": len(request_data.get("prompts", [])), "selected_prompts": len([p for p in request_data.get("prompts", []) if p.get("selected", True)]), "image_size": f"{image_width}x{image_height}", "prompt_types": [p.get("type") for p in request_data.get("prompts", [])], "files": { "request": "request.json", "response": "response.json" } } summary_path = os.path.join(request_dir, "summary.json") with open(summary_path, "w", encoding="utf-8") as f: json.dump(summary, f, indent=2, ensure_ascii=False) logger.info(f"✅ Лог запроса сохранен: {request_dir}") except Exception as e: logger.error(f"❌ Ошибка при сохранении лога: {e}") # Не прерываем обработку запроса если не удалось сохранить лог def load_model(checkpoint_path: str = "checkpoints/sam2.1_hiera_tiny.pt"): """ Загружает модель SAM2. Вызывается один раз при старте сервера. """ global predictor, device try: from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor # Проверяем CUDA device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Используем устройство: {device}") if device == "cpu": logger.warning("CUDA недоступна, работаем на CPU (будет медленно как черепаха)") # Определяем конфиг по имени файла чекпоинта # Указываем путь относительно configs/ директории в пакете sam2 checkpoint_name = os.path.basename(checkpoint_path) if "tiny" in checkpoint_name: config = "configs/sam2.1/sam2.1_hiera_t.yaml" elif "small" in checkpoint_name: config = "configs/sam2.1/sam2.1_hiera_s.yaml" elif "base_plus" in checkpoint_name: config = "configs/sam2.1/sam2.1_hiera_b+.yaml" elif "large" in checkpoint_name: config = "configs/sam2.1/sam2.1_hiera_l.yaml" else: logger.warning(f"Неизвестный тип модели, пробую tiny конфиг") config = "configs/sam2.1/sam2.1_hiera_t.yaml" logger.info(f"Загружаю модель из {checkpoint_path}") logger.info(f"Конфиг: {config}") sam2_model = build_sam2(config, checkpoint_path, device=device) predictor = SAM2ImagePredictor(sam2_model) logger.info("✓ Модель загружена успешно") except Exception as e: logger.error(f"Не удалось загрузить модель: {e}") logger.error("Убедись что SAM2 установлен (./install_sam2.sh)") raise @asynccontextmanager async def lifespan(app: FastAPI): """Загружаем модель при старте, выгружаем при остановке""" # Startup checkpoint_dir = "checkpoints" if os.path.exists(checkpoint_dir): checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")] if checkpoints: checkpoint_path = os.path.join(checkpoint_dir, checkpoints[0]) load_model(checkpoint_path) else: logger.error("Нет чекпоинтов в директории checkpoints/") logger.error("Запусти: python download_model.py") else: logger.error("Директория checkpoints/ не найдена") yield # Сервер работает # Shutdown (если нужна очистка) # Создаем FastAPI приложение с lifespan app = FastAPI( title="SAM2 Segmentation API", description="API для автоматической сегментации объектов на изображениях", version="1.0.0", lifespan=lifespan ) # Добавляем CORS для работы с веб-интерфейсом app.add_middleware( CORSMiddleware, allow_origins=["*"], # В продакшене указать конкретные домены allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): """Главная страница - информация об API""" return { "message": "SAM2 Segmentation API работает", "version": "2.0.0", "web_ui": { "simple": "/web - Box промпты", "advanced": "/web/advanced - Box + Brush промпты (рисование)" }, "docs": "/docs", "endpoints": { "POST /segment": "Сегментация изображения (поддерживает points, box, mask via query params)", "POST /segment/batch": "🔥 Батчинг сегментация (JSON API для множественных объектов)", "POST /segment/auto": "Автоматическая сегментация всех объектов", "GET /health": "Проверка здоровья сервиса" } } @app.get("/web", response_class=HTMLResponse) async def web_interface(): """Веб-интерфейс для тестирования Box Prompts (простой)""" web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo.html") if os.path.exists(web_demo_path): with open(web_demo_path, "r", encoding="utf-8") as f: return f.read() else: return "
Файл web_demo.html отсутствует
" @app.get("/web/advanced", response_class=HTMLResponse) async def web_interface_advanced(): """Продвинутый веб-интерфейс с Box + Brush промптами""" web_demo_path = os.path.join(os.path.dirname(__file__), "web_demo_advanced.html") if os.path.exists(web_demo_path): with open(web_demo_path, "r", encoding="utf-8") as f: return f.read() else: return "Файл web_demo_advanced.html отсутствует
" @app.get("/health") async def health(): """Проверка что всё ок""" return { "status": "healthy" if predictor is not None else "model not loaded", "device": str(device) if device else "unknown" } def process_image(image_bytes: bytes) -> np.ndarray: """Конвертирует байты в numpy array""" image = Image.open(io.BytesIO(image_bytes)) if image.mode != "RGB": image = image.convert("RGB") return np.array(image) def masks_to_coords(masks: np.ndarray, include_contours: bool = False) -> List[Dict[str, Any]]: """ Конвертирует маски в координаты bounding box и контуров. masks: (N, H, W) - N масок include_contours: если True, добавляет контуры масок """ results = [] for i, mask in enumerate(masks): # Находим координаты пикселей маски y_coords, x_coords = np.where(mask > 0) if len(x_coords) == 0: continue # Bounding box x_min, x_max = int(x_coords.min()), int(x_coords.max()) y_min, y_max = int(y_coords.min()), int(y_coords.max()) # Площадь сегмента area = int(mask.sum()) segment_data = { "segment_id": i, "bbox": { "x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max, "width": x_max - x_min, "height": y_max - y_min }, "area": area, "center": { "x": int(x_coords.mean()), "y": int(y_coords.mean()) } } # Добавляем контуры если нужно if include_contours: try: # Конвертируем маску в uint8 (защита от булевых масок) if mask.dtype == bool: mask_uint8 = mask.astype(np.uint8) * 255 else: mask_uint8 = (mask * 255).astype(np.uint8) # Находим контуры с иерархией для поддержки "дыр" # RETR_CCOMP: находит внешние контуры И внутренние дыры (holes) # CHAIN_APPROX_NONE: сохраняет ВСЕ точки для pixel-perfect результата contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) except Exception as e: logger.warning(f"Ошибка при извлечении контуров: {e}, использую fallback") # Fallback на простое извлечение без иерархии contours, hierarchy = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) hierarchy = None # Конвертируем контуры в список точек с учетом иерархии contour_data = [] if hierarchy is not None and len(contours) > 0: hierarchy = hierarchy[0] # OpenCV возвращает hierarchy в странном формате for i, contour in enumerate(contours): try: # Небольшое упрощение только для очень больших контуров if len(contour) > 1000: arc_length = cv2.arcLength(contour, True) if arc_length > 0: # Защита от деления на 0 epsilon = 0.0005 * arc_length approx = cv2.approxPolyDP(contour, epsilon, True) else: approx = contour else: approx = contour # Конвертируем в список [x, y] points = [[int(point[0][0]), int(point[0][1])] for point in approx] if len(points) > 2: # hierarchy[i] = [Next, Previous, First_Child, Parent] # Если Parent == -1, это внешний контур # Если Parent >= 0, это дыра (hole) внутри родительского контура is_hole = hierarchy[i][3] != -1 contour_data.append({ "points": points, "is_hole": is_hole }) except Exception as e: logger.warning(f"Ошибка при обработке контура {i}: {e}") continue else: # Fallback если hierarchy не вернулась for contour in contours: try: if len(contour) > 1000: arc_length = cv2.arcLength(contour, True) if arc_length > 0: epsilon = 0.0005 * arc_length approx = cv2.approxPolyDP(contour, epsilon, True) else: approx = contour else: approx = contour points = [[int(point[0][0]), int(point[0][1])] for point in approx] if len(points) > 2: contour_data.append({ "points": points, "is_hole": False }) except Exception as e: logger.warning(f"Ошибка при обработке контура: {e}") continue segment_data["contours"] = contour_data if len(contour_data) > 0 else [] # Также добавляем RLE (Run-Length Encoding) для компактного представления # Это полезно если нужно восстановить точную маску segment_data["mask_rle"] = mask_to_rle(mask) results.append(segment_data) return results def mask_to_rle(mask: np.ndarray) -> Dict[str, Any]: """ Конвертирует бинарную маску в RLE (Run-Length Encoding) Компактное представление маски """ # Конвертируем в int если это bool if mask.dtype == bool: pixels = mask.astype(np.uint8).flatten() else: pixels = mask.flatten() pixels = np.concatenate([[0], pixels, [0]]) runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 runs[1::2] -= runs[::2] return { "counts": [int(x) for x in runs], # Конвертируем numpy int в Python int "size": [int(x) for x in mask.shape] # Конвертируем в Python int } def convert_to_native_types(obj): """ Рекурсивно конвертирует numpy типы в нативные Python типы Нужно для сериализации в JSON через FastAPI """ if isinstance(obj, np.integer): return int(obj) elif isinstance(obj, np.floating): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, np.bool_): return bool(obj) elif isinstance(obj, dict): return {key: convert_to_native_types(value) for key, value in obj.items()} elif isinstance(obj, list): return [convert_to_native_types(item) for item in obj] return obj def clean_mask(mask: np.ndarray, min_area: int = 100) -> np.ndarray: """ Очищает маску от мелких артефактов и дыр. Более мягкий вариант - не убивает тонкие детали типа лямок. mask: бинарная маска (H, W) min_area: минимальная площадь компонента в пикселях Returns: очищенная маска """ # Конвертируем в uint8 если нужно if mask.dtype == bool: mask_uint8 = mask.astype(np.uint8) * 255 else: mask_uint8 = (mask * 255).astype(np.uint8) # Только легкое закрытие для удаления мелких дыр внутри объекта # Используем маленький kernel чтобы не убить тонкие детали (лямки, пальцы и т.д.) kernel = np.ones((2, 2), np.uint8) mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=1) # УБРАЛ MORPH_OPEN - он убивал тонкие элементы типа лямок портфеля # УБРАЛ фильтрацию по площади компонентов - она тоже могла вырезать лямки return (mask_uint8 > 127).astype(bool) def extract_object_image(image: np.ndarray, mask: np.ndarray, clean: bool = True) -> str: """ Вырезает объект из изображения по маске и возвращает base64 PNG с прозрачностью. image: RGB изображение (H, W, 3) mask: бинарная маска (H, W) clean: применить постобработку для удаления артефактов Returns: base64 строка PNG изображения с альфа-каналом """ # Конвертируем маску в bool если нужно if mask.dtype != bool: mask = mask > 0.5 # Очищаем маску от артефактов if clean: mask = clean_mask(mask, min_area=100) # Создаем RGBA изображение h, w = image.shape[:2] rgba = np.zeros((h, w, 4), dtype=np.uint8) rgba[:, :, :3] = image # RGB каналы rgba[:, :, 3] = (mask * 255).astype(np.uint8) # Alpha канал из маски # Конвертируем в PIL Image pil_image = Image.fromarray(rgba, 'RGBA') # Конвертируем в base64 buffer = io.BytesIO() pil_image.save(buffer, format='PNG') buffer.seek(0) img_base64 = base64.b64encode(buffer.read()).decode('utf-8') return f"data:image/png;base64,{img_base64}" @app.post("/segment") async def segment_image( file: UploadFile = File(...), point_x: List[float] = Query(None, description="X координаты точек промпта"), point_y: List[float] = Query(None, description="Y координаты точек промпта"), point_labels: List[int] = Query(None, description="Лейблы точек (1=foreground, 0=background)"), box_x1: float = Query(None, description="X координата левого верхнего угла бокса"), box_y1: float = Query(None, description="Y координата левого верхнего угла бокса"), box_x2: float = Query(None, description="X координата правого нижнего угла бокса"), box_y2: float = Query(None, description="Y координата правого нижнего угла бокса"), mask_data: str = Query(None, description="Base64 закодированная маска (PNG с альфа-каналом)"), include_masks: bool = Query(True, description="Включить контуры масок в ответ"), extract_objects: bool = Query(False, description="Вернуть вырезанные объекты как base64 PNG"), ): """ Сегментирует изображение по промпту (точкам, боксу, маске или их комбинации). Поддерживаемые промпты: - Точки (point_x, point_y, point_labels) - клики пользователя - Бокс (box_x1, box_y1, box_x2, box_y2) - прямоугольное выделение - Маска (mask_data) - нарисованная кистью маска (зеленый=foreground, красный=background) - Комбинация промптов - для максимальной точности Если промпты не указаны, сегментирует центральный объект. Если include_masks=True, возвращает контуры масок для точной отрисовки. Если extract_objects=True, возвращает готовые вырезанные объекты как base64 PNG. """ if predictor is None: raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер") try: # Читаем изображение image_bytes = await file.read() image = process_image(image_bytes) logger.info(f"Обрабатываю изображение: {image.shape}") logger.info(f"Параметры: include_masks={include_masks}, extract_objects={extract_objects}") # Устанавливаем изображение в предиктор predictor.set_image(image) # Подготавливаем промпты points = None labels = None box = None # Проверяем наличие точек if point_x and point_y: if len(point_x) != len(point_y): raise HTTPException(status_code=400, detail="Количество X и Y координат должно совпадать") points = np.array([[x, y] for x, y in zip(point_x, point_y)]) labels = np.array(point_labels) if point_labels else np.ones(len(points)) logger.info(f"Промпт: {len(points)} точек") # Проверяем наличие бокса if all(v is not None for v in [box_x1, box_y1, box_x2, box_y2]): box = np.array([box_x1, box_y1, box_x2, box_y2]) logger.info(f"Промпт: бокс [{box_x1:.1f}, {box_y1:.1f}, {box_x2:.1f}, {box_y2:.1f}]") # Валидация бокса if box_x2 <= box_x1 or box_y2 <= box_y1: raise HTTPException( status_code=400, detail="Некорректный бокс: x2 должен быть больше x1, y2 больше y1" ) # Проверяем наличие нарисованной маски if mask_data: logger.info("Обрабатываю нарисованную маску...") try: # Декодируем base64 if ',' in mask_data: mask_data = mask_data.split(',')[1] # Убираем data:image/png;base64, mask_bytes = base64.b64decode(mask_data) mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA') mask_array = np.array(mask_image) # Извлекаем foreground и background пиксели # Поддерживаем несколько форматов: # 1. Зеленый (R<100, G>150, B<100) - классический foreground # 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами # 3. Красный (R>150, G<100, B<100) - background green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0) red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) # Объединяем зеленые и белые как foreground foreground_mask = green_mask | white_mask # Сэмплируем точки из закрашенных областей mask_points = [] mask_labels = [] # Foreground точки (зеленые + белые) foreground_coords = np.argwhere(foreground_mask) if len(foreground_coords) > 0: # Масштабируем к размеру исходного изображения scale_y = image.shape[0] / mask_array.shape[0] scale_x = image.shape[1] / mask_array.shape[1] # Сэмплируем до 20 точек равномерно (меньше = стабильнее) step = max(1, len(foreground_coords) // 20) sampled = foreground_coords[::step][:20] # Максимум 20 точек for y, x in sampled: mask_points.append([x * scale_x, y * scale_y]) mask_labels.append(1) # foreground # Background точки (красные) red_coords = np.argwhere(red_mask) if len(red_coords) > 0: scale_y = image.shape[0] / mask_array.shape[0] scale_x = image.shape[1] / mask_array.shape[1] step = max(1, len(red_coords) // 20) sampled = red_coords[::step][:20] # Максимум 20 точек for y, x in sampled: mask_points.append([x * scale_x, y * scale_y]) mask_labels.append(0) # background if mask_points: # Объединяем с существующими точками if points is not None: points = np.vstack([points, np.array(mask_points)]) labels = np.concatenate([labels, np.array(mask_labels)]) else: points = np.array(mask_points) labels = np.array(mask_labels) logger.info(f"Промпт из маски: {len(mask_points)} точек ({np.sum(np.array(mask_labels) == 1)} foreground, {np.sum(np.array(mask_labels) == 0)} background)") else: logger.warning("Маска пустая или не содержит foreground (зеленых/белых) или background (красных) пикселей") except Exception as e: logger.error(f"Ошибка обработки маски: {e}") raise HTTPException(status_code=400, detail=f"Некорректная маска: {str(e)}") # Делаем предсказание с промптами if points is not None or box is not None: logger.info(f"Используем промпты: points={points is not None}, box={box is not None}") # Если много точек (>10), используем single mask для стабильности # Если мало точек или только box, используем multimask для вариативности use_multimask = True if points is not None and len(points) > 10: use_multimask = False logger.info("Много точек, используем single mask mode для стабильности") masks, scores, logits = predictor.predict( point_coords=points, point_labels=labels, box=box, multimask_output=use_multimask, ) # Если multimask, выбираем лучшую по score if use_multimask and len(masks) > 1: best_idx = np.argmax(scores) masks = masks[best_idx:best_idx+1] scores = scores[best_idx:best_idx+1] logger.info(f"Выбрана маска {best_idx} с confidence {scores[0]:.3f}") else: # Автоматическая сегментация - берем центральную точку logger.info("Промпты не указаны, сегментирую центральный объект") h, w = image.shape[:2] point = np.array([[w // 2, h // 2]]) label = np.array([1]) masks, scores, logits = predictor.predict( point_coords=point, point_labels=label, multimask_output=True, ) # Конвертируем маски в координаты (с контурами если нужно) segments = masks_to_coords(masks, include_contours=include_masks) logger.info(f"Найдено сегментов: {len(segments)}, масок: {len(masks)}") logger.info(f"extract_objects = {extract_objects}") # Добавляем confidence scores for i, seg in enumerate(segments): seg["confidence"] = float(scores[i]) if i < len(scores) else 0.0 # Если нужно - вырезаем объект и добавляем base64 logger.info(f"Обрабатываю сегмент {i}: extract_objects={extract_objects}, i < len(masks) = {i < len(masks)}") if extract_objects and i < len(masks): logger.info(f"Вырезаю объект {i}...") seg["extracted_image"] = extract_object_image(image, masks[i]) logger.info(f"✓ Вырезан объект {i}, размер маски: {masks[i].sum()} пикселей") else: logger.warning(f"❌ Пропускаю объект {i}: extract_objects={extract_objects}") result = { "success": True, "image_size": { "width": int(image.shape[1]), "height": int(image.shape[0]) }, "segments_count": len(segments), "segments": segments } # Конвертируем все numpy типы в нативные Python типы return convert_to_native_types(result) except Exception as e: logger.error(f"Ошибка при сегментации: {e}") raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") @app.post("/segment/auto") async def segment_auto( file: UploadFile = File(...), points_per_side: int = Query(32, description="Количество точек на сторону для автосегментации"), include_masks: bool = Query(True, description="Включить контуры масок в ответ"), ): """ Автоматическая сегментация всех объектов на изображении. Использует grid of points для поиска всех возможных объектов. Если include_masks=True, возвращает контуры масок для точной отрисовки. """ if predictor is None: raise HTTPException(status_code=503, detail="Модель не загружена") try: image_bytes = await file.read() image = process_image(image_bytes) logger.info(f"Автосегментация изображения: {image.shape}") predictor.set_image(image) # Создаем сетку точек h, w = image.shape[:2] x_coords = np.linspace(0, w, points_per_side) y_coords = np.linspace(0, h, points_per_side) all_segments = [] segment_id = 0 # Для каждой точки в сетке пытаемся найти объект for y in y_coords: for x in x_coords: point = np.array([[x, y]]) label = np.array([1]) masks, scores, _ = predictor.predict( point_coords=point, point_labels=label, multimask_output=False, ) if masks.shape[0] > 0 and scores[0] > 0.5: # Порог confidence segments = masks_to_coords(masks, include_contours=include_masks) for seg in segments: seg["segment_id"] = segment_id seg["confidence"] = float(scores[0]) all_segments.append(seg) segment_id += 1 # Убираем дубликаты (примерно) # Два сегмента считаем дубликатами если их центры близко unique_segments = [] for seg in all_segments: is_duplicate = False for unique_seg in unique_segments: dx = seg["center"]["x"] - unique_seg["center"]["x"] dy = seg["center"]["y"] - unique_seg["center"]["y"] dist = (dx**2 + dy**2) ** 0.5 if dist < 50: # Порог расстояния между центрами is_duplicate = True break if not is_duplicate: unique_segments.append(seg) result = { "success": True, "image_size": { "width": int(image.shape[1]), "height": int(image.shape[0]) }, "segments_count": len(unique_segments), "segments": unique_segments } # Конвертируем все numpy типы в нативные Python типы return convert_to_native_types(result) except Exception as e: logger.error(f"Ошибка при автосегментации: {e}") raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") @app.post("/segment/batch", response_model=BatchSegmentResponse) async def segment_batch(request: BatchSegmentRequest = Body(...)): """ Батчинг сегментация нескольких объектов. Принимает изображение и массив промптов (mask/box/points). Обрабатывает каждый selected промпт отдельно. Возвращает массив результатов с метаданными. Идеально для: - Множественных объектов - Мобильных приложений - Когда фронт уже разделил объекты """ if predictor is None: raise HTTPException(status_code=503, detail="Модель не загружена, перезапусти сервер") try: # Декодируем изображение из base64 image_data = request.image if ',' in image_data: image_data = image_data.split(',')[1] # Убираем data:image/...;base64, image_bytes = base64.b64decode(image_data) image = process_image(image_bytes) logger.info(f"Батчинг сегментация: {image.shape}, промптов: {len(request.prompts)}") # Устанавливаем изображение один раз predictor.set_image(image) results = [] # Фильтруем только selected промпты selected_prompts = [p for p in request.prompts if p.selected] logger.info(f"Обрабатываем {len(selected_prompts)} из {len(request.prompts)} промптов") # Обрабатываем каждый промпт отдельно for prompt in selected_prompts: logger.info(f"Обрабатываю промпт #{prompt.id}, тип: {prompt.type}, label: {prompt.label}") try: # Подготавливаем промпт в зависимости от типа points = None labels = None box = None if prompt.type == "mask": # Декодируем маску и извлекаем точки mask_data = prompt.data if ',' in mask_data: mask_data = mask_data.split(',')[1] mask_bytes = base64.b64decode(mask_data) mask_image = Image.open(io.BytesIO(mask_bytes)).convert('RGBA') mask_array = np.array(mask_image) # Извлекаем foreground и background пиксели # Поддерживаем несколько форматов: # 1. Зеленый (R<100, G>150, B<100) - классический foreground # 2. Белый/светлый (R>200, G>200, B>200) - часто используется фронтами # 3. Красный (R>150, G<100, B<100) - background green_mask = (mask_array[:, :, 0] < 100) & (mask_array[:, :, 1] > 150) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) white_mask = (mask_array[:, :, 0] > 200) & (mask_array[:, :, 1] > 200) & (mask_array[:, :, 2] > 200) & (mask_array[:, :, 3] > 0) red_mask = (mask_array[:, :, 0] > 150) & (mask_array[:, :, 1] < 100) & (mask_array[:, :, 2] < 100) & (mask_array[:, :, 3] > 0) # Объединяем зеленые и белые как foreground foreground_mask = green_mask | white_mask mask_points = [] mask_labels = [] # Foreground точки (зеленые + белые) foreground_coords = np.argwhere(foreground_mask) if len(foreground_coords) > 0: scale_y = image.shape[0] / mask_array.shape[0] scale_x = image.shape[1] / mask_array.shape[1] step = max(1, len(foreground_coords) // 20) sampled = foreground_coords[::step][:20] for y, x in sampled: mask_points.append([x * scale_x, y * scale_y]) mask_labels.append(1) # Background точки red_coords = np.argwhere(red_mask) if len(red_coords) > 0: scale_y = image.shape[0] / mask_array.shape[0] scale_x = image.shape[1] / mask_array.shape[1] step = max(1, len(red_coords) // 20) sampled = red_coords[::step][:20] for y, x in sampled: mask_points.append([x * scale_x, y * scale_y]) mask_labels.append(0) if mask_points: points = np.array(mask_points) labels = np.array(mask_labels) elif prompt.type == "box": # Парсим bbox - может быть нормализованный (0-1) или пиксельный bbox_data = prompt.bbox if prompt.bbox else None if bbox_data: x1 = bbox_data.x_min y1 = bbox_data.y_min x2 = bbox_data.x_max y2 = bbox_data.y_max # Если нормализованные координаты (0-1), конвертируем в пиксели if x2 <= 1.0 and y2 <= 1.0: x1 *= image.shape[1] x2 *= image.shape[1] y1 *= image.shape[0] y2 *= image.shape[0] box = np.array([x1, y1, x2, y2]) elif prompt.type == "points": # Ожидаем JSON в формате [[x, y, label], ...] import json points_data = json.loads(prompt.data) points_list = [] labels_list = [] for point in points_data: x, y = point[0], point[1] label = point[2] if len(point) > 2 else 1 # Если нормализованные, конвертируем if x <= 1.0 and y <= 1.0: x *= image.shape[1] y *= image.shape[0] points_list.append([x, y]) labels_list.append(label) points = np.array(points_list) labels = np.array(labels_list) # Делаем предсказание if points is not None or box is not None: # Решаем использовать ли multimask use_multimask = True if points is not None and len(points) > 10: use_multimask = False masks, scores, logits = predictor.predict( point_coords=points, point_labels=labels, box=box, multimask_output=use_multimask, ) # Если multimask, выбираем лучшую if use_multimask and len(masks) > 1: best_idx = np.argmax(scores) masks = masks[best_idx:best_idx+1] scores = scores[best_idx:best_idx+1] # Берем первую маску mask = masks[0] score = float(scores[0]) # Очищаем маску если нужно if request.options.clean_masks: mask = clean_mask(mask, min_area=100) # Вычисляем метрики y_coords, x_coords = np.where(mask > 0) if len(x_coords) > 0: x_min, x_max = int(x_coords.min()), int(x_coords.max()) y_min, y_max = int(y_coords.min()), int(y_coords.max()) area = int(mask.sum()) center_x = int(x_coords.mean()) center_y = int(y_coords.mean()) # Формируем результат result = { "id": prompt.id, "label": prompt.label, "bbox": { "x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max, "width": x_max - x_min, "height": y_max - y_min }, "area": area, "center": { "x": center_x, "y": center_y }, "confidence": score } # Добавляем вырезанный объект если нужно if request.options.extract_objects: result["extracted_image"] = extract_object_image( image, mask, clean=request.options.clean_masks ) # Добавляем контуры если нужно if request.options.include_masks: segments = masks_to_coords(masks, include_contours=True) if segments: result["contours"] = segments[0].get("contours", []) result["mask_rle"] = segments[0].get("mask_rle", {}) results.append(result) logger.info(f"✓ Промпт #{prompt.id} обработан, confidence: {score:.3f}") else: logger.warning(f"✗ Промпт #{prompt.id} не дал результата") else: logger.warning(f"✗ Промпт #{prompt.id}: нет данных для сегментации") except Exception as e: logger.error(f"✗ Ошибка обработки промпта #{prompt.id}: {e}") # Продолжаем обработку остальных промптов continue response = { "success": True, "image_size": { "width": int(image.shape[1]), "height": int(image.shape[0]) }, "results": results } logger.info(f"Батчинг завершен: обработано {len(results)} объектов") # Сохраняем лог запроса для аудита (только метаданные, без изображений) try: request_dict = request.dict() save_batch_request_log(request_dict, response, image.shape[1], image.shape[0]) except Exception as e: logger.warning(f"Не удалось сохранить лог запроса: {e}") return convert_to_native_types(response) except Exception as e: logger.error(f"Ошибка при батчинг сегментации: {e}") raise HTTPException(status_code=500, detail=f"Ошибка обработки: {str(e)}") if __name__ == "__main__": import uvicorn import os # Порт из переменной окружения (для HF Spaces) или 8000 по умолчанию port = int(os.getenv("PORT", 8000)) uvicorn.run(app, host="0.0.0.0", port=port)