Spaces:
Runtime error
Runtime error
| from typing import List, Tuple, Optional, Dict, Any | |
| from shapely.validation import make_valid | |
| from shapely.geometry import Polygon | |
| from rfdetr import RFDETRSegPreview | |
| from collections import defaultdict | |
| import numpy as np | |
| import cv2 | |
| import os | |
| from image_processing import ( | |
| load_with_torchvision, | |
| preprocess_resize_torch_transform, | |
| upscale_bbox, | |
| upscale_mask_opencv, | |
| crop_line | |
| ) | |
| from utils import get_default_region, get_line_regions, order_regions_lines | |
class SegmentImage:
    """
    Document image segmentation for detecting text regions and lines.

    Uses an RFDETR segmentation model to detect and extract text regions and
    lines from document images. Includes polygon merging, validation, and
    ordering.

    Args:
        model_path: Path to the RFDETR segmentation model weights
        max_size: Maximum dimension (height or width) for image preprocessing (default: 768)
        confidence_threshold: Minimum confidence score for detections (default: 0.15, range: 0-1)
        line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            (default: 7e-05, i.e., 0.007% of image)
        region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            (default: 7e-05, i.e., 0.007% of image)
        line_iou: IoU threshold for merging overlapping line polygons (default: 0.3, range: 0-1)
        region_iou: IoU threshold for merging overlapping region polygons (default: 0.3, range: 0-1)
        line_overlap_threshold: Area overlap ratio threshold for merging lines (default: 0.5, range: 0-1)
        region_overlap_threshold: Area overlap ratio threshold for merging regions (default: 0.5, range: 0-1)
        class_id_region: Class ID identifying regions in the segmentation model output (default: 1)
        class_id_line: Class ID identifying lines in the segmentation model output (default: 2)
        min_polygon_points: Minimum number of contour points for a contour to be kept
            as a polygon (default: 3)

    Raises:
        FileNotFoundError: If ``model_path`` does not exist
        RuntimeError: If the model cannot be initialized
    """

    def __init__(self,
                 model_path: str,
                 max_size: int = 768,
                 confidence_threshold: float = 0.15,
                 line_percentage_threshold: float = 7e-05,
                 region_percentage_threshold: float = 7e-05,
                 line_iou: float = 0.3,
                 region_iou: float = 0.3,
                 line_overlap_threshold: float = 0.5,
                 region_overlap_threshold: float = 0.5,
                 class_id_region: int = 1,
                 class_id_line: int = 2,
                 min_polygon_points: int = 3):
        self.model_path = model_path
        self.max_size = max_size
        self.confidence_threshold = confidence_threshold
        self.line_percentage_threshold = line_percentage_threshold
        self.region_percentage_threshold = region_percentage_threshold
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap_threshold = line_overlap_threshold
        self.region_overlap_threshold = region_overlap_threshold
        self.class_id_region = class_id_region
        self.class_id_line = class_id_line
        self.min_polygon_points = min_polygon_points

        # Fail early with a clear error instead of a cryptic loader failure.
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        self.init_model()

    def init_model(self) -> None:
        """
        Load and optimize an RFDETR segmentation model for inference.

        Raises:
            RuntimeError: If model initialization fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
            self.model.optimize_for_inference()
            print("✓ Segmentation model initialized successfully")
        except Exception as e:
            # Chain the cause so the underlying loader traceback is preserved.
            raise RuntimeError(f'Failed to initialize segmentation model: {e}') from e

    def validate_polygon(self, polygon: np.ndarray) -> Optional["Polygon"]:
        """
        Test and correct the validity of a polygon using Shapely.

        Converts the coordinate array to a Shapely Polygon and, when it is
        invalid, repairs it with make_valid().

        Args:
            polygon: Array of polygon coordinates with shape (N, 2)

        Returns:
            A valid Shapely geometry, or None if the input has fewer than
            3 points or the conversion fails. Note that make_valid() may
            return a multi-part geometry (e.g. MultiPolygon or
            GeometryCollection) rather than a single Polygon.
        """
        # A polygon needs at least three vertices.
        if len(polygon) < 3:
            return None

        try:
            shapely_polygon = Polygon(polygon)
            if not shapely_polygon.is_valid:
                shapely_polygon = make_valid(shapely_polygon)
            return shapely_polygon
        except Exception as e:
            # Best effort: a single bad polygon must not abort the whole page.
            print(f"Warning: Failed to validate polygon: {e}")
            return None

    def merge_polygons(self,
                       polygons: List[np.ndarray],
                       polygon_iou: float,
                       overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
        """
        Merge overlapping polygons using connected components (union-find).

        Two polygons are linked when their IoU exceeds ``polygon_iou`` or,
        failing that, when the intersection covers more than
        ``overlap_threshold`` of the smaller polygon. Every connected
        component of linked polygons is merged into a single geometry.

        Args:
            polygons: List of polygon coordinate arrays, each with shape (N, 2)
            polygon_iou: IoU threshold for merging (0-1)
            overlap_threshold: Minimum area overlap ratio for merging (0-1);
                a value <= 0 disables the overlap-ratio criterion

        Returns:
            Tuple of:
                - merged_polygons: List of merged polygon coordinate arrays
                - polygon_mapping: List mapping each input polygon index to its
                  output polygon index (-1 if the polygon was invalid/skipped
                  or its merged geometry contained no polygon). When a merge
                  yields several polygons, the sources map to the first one.
        """
        n = len(polygons)
        if n == 0:
            return [], []

        validated = [self.validate_polygon(p) for p in polygons]
        # Explicit checks instead of relying on geometry truthiness.
        usable = [geom is not None and not geom.is_empty for geom in validated]
        areas = [geom.area if ok else 0.0 for geom, ok in zip(validated, usable)]

        parent = list(range(n))

        def find(x: int) -> int:
            """Find the root of x (iterative, with path halving)."""
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        def union(x: int, y: int) -> None:
            """Union the two sets containing x and y."""
            px, py = find(x), find(y)
            if px != py:
                parent[px] = py

        # Link every pair of polygons that satisfies a merge criterion.
        for i in range(n):
            if not usable[i]:
                continue
            poly1 = validated[i]
            for j in range(i + 1, n):
                if not usable[j]:
                    continue
                poly2 = validated[j]
                if not poly1.intersects(poly2):
                    continue

                inter_area = poly1.intersection(poly2).area
                # |A ∪ B| = |A| + |B| - |A ∩ B|; avoids building the union geometry.
                union_area = areas[i] + areas[j] - inter_area
                iou = inter_area / union_area if union_area > 0 else 0

                should_merge = iou > polygon_iou
                # Fallback criterion: a small polygon mostly covered by a large one.
                if not should_merge and overlap_threshold > 0:
                    smaller_area = min(areas[i], areas[j])
                    overlap_ratio = inter_area / smaller_area if smaller_area > 0 else 0
                    should_merge = overlap_ratio > overlap_threshold

                if should_merge:
                    union(i, j)

        # Group usable polygons by the root of their component.
        components = defaultdict(list)
        for i in range(n):
            if usable[i]:
                components[find(i)].append(i)

        merged_polygons = []
        polygon_mapping = [-1] * n  # -1 indicates invalid/unmapped polygon

        for indices in components.values():
            if len(indices) == 1:
                # Single polygon: keep the original coordinates untouched.
                idx = indices[0]
                polygon_mapping[idx] = len(merged_polygons)
                merged_polygons.append(polygons[idx])
                continue

            merged = validated[indices[0]]
            for idx in indices[1:]:
                merged = merged.union(validated[idx])

            # Collect the polygonal parts of the merged geometry.
            if merged.geom_type == 'Polygon':
                parts = [merged]
            else:
                parts = [
                    geom for geom in getattr(merged, 'geoms', [])
                    if geom.geom_type == 'Polygon'
                ]
            parts = [part for part in parts if not part.is_empty]

            if not parts:
                # Nothing polygonal came out; leave the sources mapped to -1
                # rather than pointing at an index that does not exist.
                continue

            output_idx = len(merged_polygons)
            for part in parts:
                # Only the outer boundary is kept; round to integer pixels.
                merged_polygons.append(
                    np.round(np.array(part.exterior.coords)).astype(np.int32)
                )
            # Map all source polygons to the first output polygon.
            for idx in indices:
                polygon_mapping[idx] = output_idx

        return merged_polygons, polygon_mapping

    def calculate_polygon_area(self, vertices: np.ndarray) -> float:
        """
        Calculate polygon area using the Shoelace formula (surveyor's formula).

        Works for simple (non-self-intersecting) polygons regardless of
        vertex orientation (clockwise or counter-clockwise).

        Args:
            vertices: Array of polygon coordinates with shape (N, 2)

        Returns:
            Area of the polygon in square pixels; 0.0 when fewer than three
            vertices are given.
        """
        # Compute in float64: contour coordinates are often int32 and their
        # products would silently overflow for large coordinates.
        pts = np.asarray(vertices, dtype=np.float64)
        if pts.ndim != 2 or len(pts) < 3:
            return 0.0

        x = pts[:, 0]
        y = pts[:, 1]
        cross_sum = (
            np.sum(x[:-1] * y[1:]) - np.sum(y[:-1] * x[1:])
            + x[-1] * y[0] - y[-1] * x[0]  # closing edge (last -> first vertex)
        )
        return float(0.5 * np.abs(cross_sum))

    def mask_to_polygon_cv2(self,
                            mask: np.ndarray,
                            original_shape: Tuple[int, int]) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Convert a binary segmentation mask to polygons using OpenCV contours.

        Extracts the external contours of the mask, scales them back to the
        original image dimensions and computes each contour's area as a
        fraction of the mask area (used by callers for filtering).

        Args:
            mask: Binary mask as numpy array (bool or uint8, 0-255)
            original_shape: Tuple of (height, width) of original image

        Returns:
            Tuple of:
                - scaled_polygons: List of (N, 2) polygon coordinate arrays
                  scaled to original size
                - area_percentages: Array of polygon areas as fraction of mask size
        """
        mask_height, mask_width = mask.shape[:2]
        mask_area = mask_height * mask_width
        if mask_area == 0:
            # Nothing to extract; also avoids dividing by zero below.
            return [], np.array([])

        # OpenCV expects an 8-bit single-channel image.
        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            mask_uint8 = mask.astype(np.uint8)

        # [-2] selects the contour list on both OpenCV 3.x (3 return values)
        # and OpenCV 4.x (2 return values).
        contours = cv2.findContours(
            mask_uint8,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )[-2]

        # Contours come as (N, 1, 2); reshape keeps a 2-D (N, 2) layout even
        # for N == 1, unlike squeeze().
        polygons = [
            contour.reshape(-1, 2)
            for contour in contours
            if len(contour) >= self.min_polygon_points
        ]

        orig_height, orig_width = original_shape
        scale = np.array([orig_width / mask_width, orig_height / mask_height])

        scaled_polygons = []
        area_percentages = []
        for poly in polygons:
            # Area is measured in mask coordinates (before scaling).
            area_percentages.append(self.calculate_polygon_area(poly) / mask_area)
            scaled_polygons.append(np.round(poly * scale).astype(int))

        return scaled_polygons, np.array(area_percentages)

    def process_polygons(self,
                         poly_masks: np.ndarray,
                         image_shape: Tuple[int, int],
                         percentage_threshold: float,
                         overlap_threshold: float,
                         iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """
        Extract polygons from segmentation masks, filter by area, and merge overlapping ones.

        Args:
            poly_masks: Array of binary segmentation masks from model
            image_shape: Tuple of (height, width) of original image
            percentage_threshold: Minimum polygon area as fraction of image
            overlap_threshold: Minimum overlap ratio for merging polygons
            iou_threshold: Minimum IoU for merging polygons

        Returns:
            Tuple of two index-aligned lists:
                - merged_polygons: List of polygon coordinate arrays
                - merged_max_mins: List of bounding boxes as (xmin, ymin, xmax, ymax) tuples
        """
        all_polygons = []
        all_area_percentages = []
        for mask in poly_masks:
            polygons, area_percentages = self.mask_to_polygon_cv2(
                mask=mask,
                original_shape=image_shape
            )
            all_polygons.extend(polygons)
            all_area_percentages.extend(area_percentages)

        # Drop detections that are too small to be meaningful.
        filtered_polygons = [
            poly
            for poly, fraction in zip(all_polygons, all_area_percentages)
            if fraction > percentage_threshold
        ]
        if not filtered_polygons:
            return [], []

        merged_polygons, _ = self.merge_polygons(
            filtered_polygons,
            iou_threshold,
            overlap_threshold
        )

        # Build both outputs in lockstep so polygon i always matches box i,
        # even if an empty polygon has to be dropped.
        kept_polygons = []
        merged_max_mins = []
        for poly in merged_polygons:
            if len(poly) == 0:
                continue
            xmax, ymax = np.max(poly, axis=0)
            xmin, ymin = np.min(poly, axis=0)
            kept_polygons.append(poly)
            merged_max_mins.append((xmin, ymin, xmax, ymax))

        return kept_polygons, merged_max_mins

    def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
        """
        Detect and extract ordered text lines and regions from a document image.

        Runs the segmentation model on the image, extracts line and region
        polygons, merges overlapping detections, associates lines with
        regions, and orders them for reading sequence.

        Args:
            image: Image object exposing ``.shape``; its first two dimensions
                are read as (height, width).
                NOTE(review): a PIL Image has no ``.shape`` attribute, so the
                caller presumably passes an array-like image — confirm
                against the caller and ``preprocess_resize_torch_transform``.

        Returns:
            The ordered lines produced by ``order_regions_lines``, or None if
            prediction failed or no lines were detected.
        """
        image_shape = (image.shape[0], image.shape[1])

        # Resize the image for model input.
        preprocessed_image = preprocess_resize_torch_transform(
            image,
            max_size=self.max_size
        )

        try:
            detections = self.model.predict(
                preprocessed_image,
                threshold=self.confidence_threshold
            )
        except Exception as e:
            print(f"Error during segmentation prediction: {e}")
            return None

        # Without masks there is nothing to turn into polygons.
        if getattr(detections, 'mask', None) is None:
            print('No text lines detected from image.')
            return None

        # Separate line and region masks by class ID.
        line_mask = detections.mask[detections.class_id == self.class_id_line]
        region_mask = detections.mask[detections.class_id == self.class_id_region]

        merged_line_polygons, merged_line_max_mins = self.process_polygons(
            line_mask,
            image_shape,
            self.line_percentage_threshold,
            self.line_overlap_threshold,
            self.line_iou
        )
        merged_region_polygons, merged_region_max_mins = self.process_polygons(
            region_mask,
            image_shape,
            self.region_percentage_threshold,
            self.region_overlap_threshold,
            self.region_iou
        )

        if not merged_line_polygons:
            print('No text lines detected from image.')
            return None

        line_preds = {
            'coords': merged_line_polygons,
            'max_min': merged_line_max_mins
        }

        if merged_region_polygons:
            region_preds = [
                {
                    'coords': region_polygon,
                    'id': str(num),
                    'max_min': region_max_min,
                    'name': 'paragraph',
                    'img_shape': image_shape
                }
                for num, (region_polygon, region_max_min) in enumerate(
                    zip(merged_region_polygons, merged_region_max_mins)
                )
            ]
        else:
            # No regions detected: fall back to the default region for this image.
            region_preds = get_default_region(image_shape=image_shape)

        # Associate lines with their containing regions.
        lines_connected_to_regions = get_line_regions(
            lines=line_preds,
            regions=region_preds
        )

        # Order lines within regions for proper reading sequence.
        ordered_lines = order_regions_lines(
            lines=lines_connected_to_regions,
            regions=region_preds
        )
        return ordered_lines