from typing import List, Tuple, Optional, Dict, Any
from shapely.validation import make_valid
from shapely.geometry import Polygon
from rfdetr import RFDETRSegPreview
from collections import defaultdict
import numpy as np
import cv2
import os
from image_processing import (
    load_with_torchvision,
    preprocess_resize_torch_transform,
    upscale_bbox,
    upscale_mask_opencv,
    crop_line
)
from utils import get_default_region, get_line_regions, order_regions_lines


class SegmentImage:
    """
    Document image segmentation for detecting text regions and lines.

    Uses an RFDETR segmentation model to detect and extract text regions and
    lines from document images. Includes polygon merging, validation, and
    ordering.

    Args:
        model_path: Path to the RFDETR segmentation model weights
        max_size: Maximum dimension (height or width) for image preprocessing
            (default: 768)
        confidence_threshold: Minimum confidence score for detections
            (default: 0.15, range: 0-1)
        line_percentage_threshold: Minimum polygon area as fraction of image
            area for lines (default: 7e-05, i.e., 0.007% of image)
        region_percentage_threshold: Minimum polygon area as fraction of image
            area for regions (default: 7e-05, i.e., 0.007% of image)
        line_iou: IoU threshold for merging overlapping line polygons
            (default: 0.3, range: 0-1)
        region_iou: IoU threshold for merging overlapping region polygons
            (default: 0.3, range: 0-1)
        line_overlap_threshold: Area overlap ratio threshold for merging lines
            (default: 0.5, range: 0-1)
        region_overlap_threshold: Area overlap ratio threshold for merging
            regions (default: 0.5, range: 0-1)
        class_id_region: Class ID identifying regions in the segmentation
            model output (default: 1)
        class_id_line: Class ID identifying lines in the segmentation model
            output (default: 2)
        min_polygon_points: Minimum number of contour points required to keep
            a contour as a polygon (default: 3)

    Raises:
        FileNotFoundError: If ``model_path`` does not exist
        RuntimeError: If the segmentation model cannot be initialized
    """

    def __init__(self,
                 model_path: str,
                 max_size: int = 768,
                 confidence_threshold: float = 0.15,
                 line_percentage_threshold: float = 7e-05,
                 region_percentage_threshold: float = 7e-05,
                 line_iou: float = 0.3,
                 region_iou: float = 0.3,
                 line_overlap_threshold: float = 0.5,
                 region_overlap_threshold: float = 0.5,
                 class_id_region: int = 1,
                 class_id_line: int = 2,
                 min_polygon_points: int = 3):
        self.model_path = model_path
        self.max_size = max_size
        self.confidence_threshold = confidence_threshold
        self.line_percentage_threshold = line_percentage_threshold
        self.region_percentage_threshold = region_percentage_threshold
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap_threshold = line_overlap_threshold
        self.region_overlap_threshold = region_overlap_threshold
        self.class_id_region = class_id_region
        self.class_id_line = class_id_line
        self.min_polygon_points = min_polygon_points

        # Fail early with a clear error before handing the path to the model.
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        self.init_model()

    def init_model(self) -> None:
        """
        Load and optimize an RFDETR segmentation model for inference.

        Stores the loaded model on ``self.model``.

        Raises:
            RuntimeError: If model initialization fails; the original
                exception is attached as the cause.
        """
        try:
            self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
            self.model.optimize_for_inference()
            print("✓ Segmentation model initialized successfully")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize segmentation model: {e}') from e

    def validate_polygon(self, polygon: np.ndarray) -> Optional[Polygon]:
        """
        Test and correct the validity of a polygon using Shapely.

        Converts the coordinate array to a Shapely Polygon and, if it is not
        valid, repairs it with ``make_valid()``.

        Args:
            polygon: Array of polygon coordinates with shape (N, 2)

        Returns:
            A valid Shapely geometry, or None if the input has fewer than
            3 points or Shapely raised while building/repairing it.
            NOTE(review): ``make_valid`` may return a non-Polygon geometry
            (e.g. a MultiPolygon or GeometryCollection) for some invalid
            inputs, so the result is not guaranteed to be a ``Polygon``.
        """
        if len(polygon) <= 2:
            return None

        try:
            shapely_polygon = Polygon(polygon)
            if not shapely_polygon.is_valid:
                shapely_polygon = make_valid(shapely_polygon)
            return shapely_polygon
        except Exception as e:
            # Best effort: an unusable polygon is skipped, not fatal.
            print(f"Warning: Failed to validate polygon: {e}")
            return None

    def merge_polygons(self,
                       polygons: List[np.ndarray],
                       polygon_iou: float,
                       overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
        """
        Merge overlapping polygons using connected components (union-find).

        Two polygons are linked when their IoU exceeds ``polygon_iou`` or,
        failing that, when the intersection covers more than
        ``overlap_threshold`` of the smaller polygon. Every connected
        component of linked polygons is merged with a Shapely union.

        Args:
            polygons: List of polygon coordinate arrays, each with shape (N, 2)
            polygon_iou: IoU threshold for merging (0-1)
            overlap_threshold: Minimum area overlap ratio for merging (0-1);
                a value <= 0 disables the overlap-ratio criterion

        Returns:
            Tuple of:
                - merged_polygons: List of polygon coordinate arrays. Polygons
                  that were not merged are returned as the original input
                  arrays; merged ones are int32 arrays of the union's exterior
                  ring (coordinates truncated by the integer cast).
                - polygon_mapping: List mapping each input polygon index to
                  its output polygon index, or -1 if the input was invalid or
                  produced no output polygon. When a merge yields several
                  polygons, inputs map to the first of them.
        """
        n = len(polygons)
        if n == 0:
            return [], []

        validated = [self.validate_polygon(p) for p in polygons]

        def is_usable(geom) -> bool:
            """Return True for a geometry that exists and is not empty."""
            return geom is not None and not geom.is_empty

        # Union-find forest: parent[i] is the representative chain for i.
        parent = list(range(n))

        def find(x: int) -> int:
            """Find the root of x (iteratively) and compress the path."""
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        def union(x: int, y: int) -> None:
            """Join the sets containing x and y."""
            px, py = find(x), find(y)
            if px != py:
                parent[px] = py

        # Link every pair of polygons that satisfies a merge criterion.
        for i in range(n):
            poly1 = validated[i]
            if not is_usable(poly1):
                continue

            for j in range(i + 1, n):
                poly2 = validated[j]
                if not is_usable(poly2) or not poly1.intersects(poly2):
                    continue

                intersection_area = poly1.intersection(poly2).area
                union_area = poly1.union(poly2).area
                iou = intersection_area / union_area if union_area > 0 else 0

                should_merge = iou > polygon_iou

                # Fallback criterion: a small polygon mostly covered by a
                # large one has a low IoU but should still be merged.
                if not should_merge and overlap_threshold > 0:
                    smaller_area = min(poly1.area, poly2.area)
                    overlap_ratio = (
                        intersection_area / smaller_area if smaller_area > 0 else 0
                    )
                    should_merge = overlap_ratio > overlap_threshold

                if should_merge:
                    union(i, j)

        # Group usable polygons by the root of their component.
        components = defaultdict(list)
        for i in range(n):
            if is_usable(validated[i]):
                components[find(i)].append(i)

        merged_polygons = []
        polygon_mapping = [-1] * n  # -1 indicates invalid/unmapped polygon

        for indices in components.values():
            output_idx = len(merged_polygons)

            if len(indices) == 1:
                # Nothing to merge: keep the caller's original array.
                idx = indices[0]
                merged_polygons.append(polygons[idx])
                polygon_mapping[idx] = output_idx
                continue

            merged = validated[indices[0]]
            for idx in indices[1:]:
                merged = merged.union(validated[idx])

            if merged.geom_type == 'Polygon':
                parts = [merged]
            elif merged.geom_type in ('MultiPolygon', 'GeometryCollection'):
                # The union can be multi-part (e.g. pieces touching at a
                # single point); keep only the polygonal parts.
                parts = [g for g in merged.geoms if g.geom_type == 'Polygon']
            else:
                parts = []

            if not parts:
                # No polygon came out of the union: leave the mapping at -1
                # instead of pointing at an index that was never written.
                continue

            for part in parts:
                merged_polygons.append(
                    np.array(part.exterior.coords).astype(np.int32)
                )

            # Map all source polygons to the first output polygon.
            for idx in indices:
                polygon_mapping[idx] = output_idx

        return merged_polygons, polygon_mapping

    def calculate_polygon_area(self, vertices: np.ndarray) -> float:
        """
        Calculate polygon area using the Shoelace formula.

        The computation is done in float64 so integer coordinate arrays
        cannot overflow in the coordinate products. The absolute value makes
        the result independent of vertex winding order; the formula is only
        meaningful for simple (non-self-intersecting) polygons.

        Args:
            vertices: Array of polygon coordinates with shape (N, 2)

        Returns:
            Area of the polygon in square coordinate units; 0.0 when the
            input is not a 2-D array with at least 3 vertices.
        """
        points = np.asarray(vertices, dtype=np.float64)
        if points.ndim != 2 or points.shape[0] < 3:
            return 0.0

        x = points[:, 0]
        y = points[:, 1]

        # np.roll(..., -1) pairs every vertex with its successor, wrapping
        # the last vertex around to the first.
        cross_sum = np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))
        return float(0.5 * np.abs(cross_sum))

    def mask_to_polygon_cv2(self,
                            mask: np.ndarray,
                            original_shape: Tuple[int, int]) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Convert a segmentation mask to polygon coordinates using OpenCV contours.

        Extracts the external contours of the mask, keeps those with at least
        ``self.min_polygon_points`` points, and scales their coordinates to
        the original image dimensions. The area of each polygon is measured
        in mask coordinates and returned as a fraction of the mask area.

        Args:
            mask: Mask as a numpy array; bool masks are mapped to 0/255, any
                other dtype is cast to uint8
            original_shape: Tuple of (height, width) of original image

        Returns:
            Tuple of:
                - scaled_polygons: List of (N, 2) integer coordinate arrays
                  scaled to the original image size
                - area_percentages: Array of polygon areas as fraction of the
                  mask area
        """
        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            mask_uint8 = mask.astype(np.uint8)

        # RETR_EXTERNAL: only outer boundaries, holes are ignored.
        contours, _ = cv2.findContours(
            mask_uint8,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        # reshape(-1, 2) keeps every polygon 2-D, even a one-point contour
        # (squeeze() would collapse that to a flat (2,) array).
        polygons = [
            contour.reshape(-1, 2)
            for contour in contours
            if len(contour) >= self.min_polygon_points
        ]

        orig_height, orig_width = original_shape
        mask_height, mask_width = mask.shape[:2]
        scale = np.array([orig_width / mask_width, orig_height / mask_height])
        mask_area = mask_height * mask_width

        scaled_polygons = []
        area_percentages = []
        for poly in polygons:
            # Area is measured before scaling so it is relative to the mask.
            area = self.calculate_polygon_area(poly)
            area_percentages.append(area / mask_area if mask_area > 0 else 0)
            scaled_polygons.append(np.round(poly * scale).astype(int))

        return scaled_polygons, np.array(area_percentages)

    def process_polygons(self,
                         poly_masks: np.ndarray,
                         image_shape: Tuple[int, int],
                         percentage_threshold: float,
                         overlap_threshold: float,
                         iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """
        Extract polygons from segmentation masks, filter by area, and merge
        overlapping ones.

        Args:
            poly_masks: Iterable of segmentation masks from the model
            image_shape: Tuple of (height, width) of original image
            percentage_threshold: Minimum polygon area as fraction of the
                mask area; polygons at or below it are discarded
            overlap_threshold: Minimum overlap ratio for merging polygons
            iou_threshold: Minimum IoU for merging polygons

        Returns:
            Tuple of two index-aligned lists:
                - merged_polygons: List of polygon coordinate arrays
                - merged_max_mins: List of bounding boxes as
                  (xmin, ymin, xmax, ymax) tuples
            Both lists are empty when nothing survives the filtering.
        """
        all_polygons = []
        all_area_percentages = []

        for mask in poly_masks:
            polygons, area_percentages = self.mask_to_polygon_cv2(
                mask=mask,
                original_shape=image_shape
            )
            all_polygons.extend(polygons)
            all_area_percentages.extend(area_percentages)

        all_area_percentages = np.array(all_area_percentages)
        if len(all_area_percentages) == 0:
            return [], []

        valid_indices = np.where(all_area_percentages > percentage_threshold)[0]
        filtered_polygons = [all_polygons[idx] for idx in valid_indices]
        if not filtered_polygons:
            return [], []

        merged_polygons, _ = self.merge_polygons(
            filtered_polygons,
            iou_threshold,
            overlap_threshold
        )

        # Build polygons and bounding boxes together so an empty polygon is
        # dropped from both lists and they stay index-aligned.
        kept_polygons = []
        merged_max_mins = []
        for poly in merged_polygons:
            if len(poly) == 0:
                continue
            xmax, ymax = np.max(poly, axis=0)
            xmin, ymin = np.min(poly, axis=0)
            kept_polygons.append(poly)
            merged_max_mins.append((xmin, ymin, xmax, ymax))

        return kept_polygons, merged_max_mins

    def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
        """
        Detect and extract ordered text lines and regions from a document image.

        Runs the segmentation model on the image, extracts line and region
        polygons, merges overlapping detections, associates lines with
        regions, and orders them with ``order_regions_lines``.

        Args:
            image: Image object exposing ``.shape``; ``shape[0]`` is used as
                the height and ``shape[1]`` as the width.
                NOTE(review): earlier documentation described this as a PIL
                Image, but PIL images have no ``.shape`` -- confirm the
                expected type against the callers.

        Returns:
            The result of ``order_regions_lines`` (presumably a list of line
            dictionaries with region associations -- verify against
            ``utils``), or None if prediction failed or no lines were
            detected.
        """
        image_shape = (image.shape[0], image.shape[1])

        preprocessed_image = preprocess_resize_torch_transform(
            image,
            max_size=self.max_size
        )

        try:
            detections = self.model.predict(
                preprocessed_image,
                threshold=self.confidence_threshold
            )
        except Exception as e:
            print(f"Error during segmentation prediction: {e}")
            return None

        # NOTE(review): assumes the detections object may carry no mask array
        # when nothing is detected -- treat that as "no lines" rather than
        # failing on None indexing below.
        masks = getattr(detections, 'mask', None)
        if masks is None:
            print('No text lines detected from image.')
            return None

        # Separate line and region masks by class ID.
        class_ids = detections.class_id
        line_mask = masks[class_ids == self.class_id_line]
        region_mask = masks[class_ids == self.class_id_region]

        merged_line_polygons, merged_line_max_mins = self.process_polygons(
            line_mask,
            image_shape,
            self.line_percentage_threshold,
            self.line_overlap_threshold,
            self.line_iou
        )

        merged_region_polygons, merged_region_max_mins = self.process_polygons(
            region_mask,
            image_shape,
            self.region_percentage_threshold,
            self.region_overlap_threshold,
            self.region_iou
        )

        if not merged_line_polygons:
            print('No text lines detected from image.')
            return None

        line_preds = {
            'coords': merged_line_polygons,
            'max_min': merged_line_max_mins
        }

        if merged_region_polygons:
            region_preds = [
                {
                    'coords': region_polygon,
                    'id': str(num),
                    'max_min': region_max_min,
                    'name': 'paragraph',
                    'img_shape': image_shape
                }
                for num, (region_polygon, region_max_min) in enumerate(
                    zip(merged_region_polygons, merged_region_max_mins)
                )
            ]
        else:
            # No regions detected: fall back to the default region built
            # from the image shape.
            region_preds = get_default_region(image_shape=image_shape)

        lines_connected_to_regions = get_line_regions(
            lines=line_preds,
            regions=region_preds
        )

        ordered_lines = order_regions_lines(
            lines=lines_connected_to_regions,
            regions=region_preds
        )

        return ordered_lines