Multicentury-HTR-Demo / segment_image.py
MikkoLipsanen's picture
Update segmentation to use rfdetr model
c99fa8d verified
from typing import List, Tuple, Optional, Dict, Any
from shapely.validation import make_valid
from shapely.geometry import Polygon
from rfdetr import RFDETRSegPreview
from collections import defaultdict
import numpy as np
import cv2
import os
from image_processing import (
load_with_torchvision,
preprocess_resize_torch_transform,
upscale_bbox,
upscale_mask_opencv,
crop_line
)
from utils import get_default_region, get_line_regions, order_regions_lines
class SegmentImage:
    """
    Document image segmentation for detecting text regions and lines.

    Uses an RFDETR segmentation model to detect and extract text regions and lines
    from document images. Includes polygon merging, validation, and ordering.

    Args:
        model_path: Path to the RFDETR segmentation model weights
        max_size: Maximum dimension (height or width) for image preprocessing (default: 768)
        confidence_threshold: Minimum confidence score for detections (default: 0.15, range: 0-1)
        line_percentage_threshold: Minimum polygon area, as a fraction of the mask area,
            for lines (default: 7e-05, i.e., 0.007% of the image)
        region_percentage_threshold: Minimum polygon area, as a fraction of the mask area,
            for regions (default: 7e-05, i.e., 0.007% of the image)
        line_iou: IoU threshold for merging overlapping line polygons (default: 0.3, range: 0-1)
        region_iou: IoU threshold for merging overlapping region polygons (default: 0.3, range: 0-1)
        line_overlap_threshold: Area overlap ratio threshold for merging lines (default: 0.5, range: 0-1)
        region_overlap_threshold: Area overlap ratio threshold for merging regions (default: 0.5, range: 0-1)
        class_id_region: Class ID identifying regions in the segmentation model output (default: 1)
        class_id_line: Class ID identifying lines in the segmentation model output (default: 2)
        min_polygon_points: Minimum number of contour points required to form a polygon (default: 3)

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the segmentation model fails to initialize.
    """

    def __init__(self,
                 model_path: str,
                 max_size: int = 768,
                 confidence_threshold: float = 0.15,
                 line_percentage_threshold: float = 7e-05,
                 region_percentage_threshold: float = 7e-05,
                 line_iou: float = 0.3,
                 region_iou: float = 0.3,
                 line_overlap_threshold: float = 0.5,
                 region_overlap_threshold: float = 0.5,
                 class_id_region: int = 1,
                 class_id_line: int = 2,
                 min_polygon_points: int = 3):
        self.model_path = model_path
        self.max_size = max_size
        self.confidence_threshold = confidence_threshold
        self.line_percentage_threshold = line_percentage_threshold
        self.region_percentage_threshold = region_percentage_threshold
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap_threshold = line_overlap_threshold
        self.region_overlap_threshold = region_overlap_threshold
        self.class_id_region = class_id_region
        self.class_id_line = class_id_line
        self.min_polygon_points = min_polygon_points

        # Fail fast with a clear error before attempting to load weights.
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")

        self.init_model()

    def init_model(self) -> None:
        """
        Load an RFDETR segmentation model and optimize it for inference.

        Raises:
            RuntimeError: If model initialization fails. The underlying exception
                is chained as ``__cause__``.
        """
        try:
            self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
            self.model.optimize_for_inference()
        except Exception as e:
            raise RuntimeError(f'Failed to initialize segmentation model: {e}') from e
        print("✓ Segmentation model initialized successfully")

    def validate_polygon(self, polygon: np.ndarray) -> Optional["Polygon"]:
        """
        Convert a coordinate array into a valid Shapely geometry.

        Builds a Shapely Polygon. If it is invalid (e.g. self-intersecting),
        this method repairs it with ``make_valid``. The repaired geometry may be
        a MultiPolygon or GeometryCollection rather than a plain Polygon.

        Args:
            polygon: Array of polygon coordinates with shape (N, 2)

        Returns:
            A valid Shapely geometry, or None if the input has fewer than 3 points
            or could not be converted.
        """
        if len(polygon) < 3:
            return None
        try:
            shapely_polygon = Polygon(polygon)
            if not shapely_polygon.is_valid:
                shapely_polygon = make_valid(shapely_polygon)
            return shapely_polygon
        except Exception as e:
            # Best effort: a single malformed contour should not abort the whole page.
            print(f"Warning: Failed to validate polygon: {e}")
            return None

    @staticmethod
    def _polygon_parts(geom) -> List["Polygon"]:
        """Flatten a Shapely geometry (possibly nested collections) into its non-empty Polygon parts."""
        if geom is None or geom.is_empty:
            return []
        if geom.geom_type == 'Polygon':
            return [geom]
        if geom.geom_type in ('MultiPolygon', 'GeometryCollection'):
            parts = []
            for sub_geom in geom.geoms:
                parts.extend(SegmentImage._polygon_parts(sub_geom))
            return parts
        # Points / lines carry no area and cannot be emitted as polygons.
        return []

    def merge_polygons(self,
                       polygons: List[np.ndarray],
                       polygon_iou: float,
                       overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
        """
        Merge overlapping polygons using connected components (union-find).

        Two polygons are linked when their IoU exceeds ``polygon_iou``. If that
        test fails, they are also linked when intersection / smaller-area exceeds
        ``overlap_threshold`` (and ``overlap_threshold > 0``). Each connected
        component is then merged into one or more output polygons.

        Args:
            polygons: List of polygon coordinate arrays, each with shape (N, 2)
            polygon_iou: IoU threshold for merging (0-1)
            overlap_threshold: Minimum area overlap ratio for merging (0-1)

        Returns:
            Tuple of:
                - merged_polygons: List of polygon coordinate arrays. Singleton
                  components keep their original array. Merged components are
                  emitted as int32 exterior rings (one per resulting Polygon part).
                - polygon_mapping: For each input index, the index of the first
                  output polygon it contributed to, or -1 if it was invalid/dropped.
        """
        n = len(polygons)
        if n == 0:
            return [], []

        validated = [self.validate_polygon(p) for p in polygons]
        usable = [geom is not None and not geom.is_empty for geom in validated]
        areas = [geom.area if ok else 0.0 for geom, ok in zip(validated, usable)]

        parent = list(range(n))

        def find(x: int) -> int:
            """Return the root of x's set, compressing the path iteratively."""
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:
                next_x = parent[x]
                parent[x] = root
                x = next_x
            return root

        def union(x: int, y: int) -> None:
            """Join the sets containing x and y."""
            root_x, root_y = find(x), find(y)
            if root_x != root_y:
                parent[root_x] = root_y

        # Pairwise overlap test between all usable polygons.
        for i in range(n):
            if not usable[i]:
                continue
            poly1 = validated[i]
            for j in range(i + 1, n):
                if not usable[j]:
                    continue
                poly2 = validated[j]
                if not poly1.intersects(poly2):
                    continue

                inter_area = poly1.intersection(poly2).area
                # Inclusion-exclusion gives the union area without building the union geometry.
                union_area = areas[i] + areas[j] - inter_area
                iou = inter_area / union_area if union_area > 0 else 0.0

                should_merge = iou > polygon_iou
                if not should_merge and overlap_threshold > 0:
                    smaller_area = min(areas[i], areas[j])
                    overlap_ratio = inter_area / smaller_area if smaller_area > 0 else 0.0
                    should_merge = overlap_ratio > overlap_threshold

                if should_merge:
                    union(i, j)

        # Group usable polygons by component root (insertion order follows first member index).
        components = defaultdict(list)
        for i in range(n):
            if usable[i]:
                components[find(i)].append(i)

        merged_polygons = []
        polygon_mapping = [-1] * n  # -1 indicates invalid/unmapped polygon
        for indices in components.values():
            output_idx = len(merged_polygons)

            if len(indices) == 1:
                # Nothing to merge: keep the caller's original coordinates.
                idx = indices[0]
                merged_polygons.append(polygons[idx])
                polygon_mapping[idx] = output_idx
                continue

            merged = validated[indices[0]]
            for idx in indices[1:]:
                merged = merged.union(validated[idx])

            # A union may split into several parts (e.g. shapes touching at a point).
            parts = self._polygon_parts(merged)
            for part in parts:
                merged_polygons.append(np.array(part.exterior.coords).astype(np.int32))

            # Only map sources when at least one output polygon actually exists.
            if parts:
                for idx in indices:
                    polygon_mapping[idx] = output_idx

        return merged_polygons, polygon_mapping

    def calculate_polygon_area(self, vertices: np.ndarray) -> float:
        """
        Calculate polygon area using the Shoelace formula.

        Works for simple (non-self-intersecting) polygons regardless of vertex
        order. The ring is closed implicitly, so the last vertex should not repeat
        the first (repeating it adds zero area anyway). Coordinates are converted
        to float64 first, so large integer coordinates cannot overflow.

        Args:
            vertices: Array of polygon coordinates with shape (N, 2) (any input
                reshapeable to (-1, 2) is accepted)

        Returns:
            Area of the polygon in square pixels. Returns 0.0 for fewer than 3 vertices.
        """
        points = np.asarray(vertices, dtype=np.float64).reshape(-1, 2)
        if len(points) < 3:
            return 0.0
        x, y = points[:, 0], points[:, 1]
        # 0.5 * |sum_i (x_i * y_{i+1} - x_{i+1} * y_i)| with wrap-around indices.
        return float(0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))))

    def mask_to_polygon_cv2(self,
                            mask: np.ndarray,
                            original_shape: Tuple[int, int]) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Convert a binary segmentation mask into polygons using OpenCV contours.

        Only external contours are extracted. Contours with fewer than
        ``min_polygon_points`` points are discarded. Coordinates are scaled from
        mask resolution to the original image resolution.

        Args:
            mask: Binary mask as numpy array (bool, or uint8 with 0/255 values)
            original_shape: Tuple of (height, width) of original image

        Returns:
            Tuple of:
                - scaled_polygons: List of (N, 2) integer arrays in original-image pixels
                - area_percentages: Array of polygon areas as a fraction of the mask
                  area, measured before scaling
        """
        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            mask_uint8 = mask.astype(np.uint8)

        contours, _ = cv2.findContours(
            mask_uint8,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        orig_height, orig_width = original_shape
        mask_height, mask_width = mask.shape[:2]
        scale = np.array([orig_width / mask_width, orig_height / mask_height])
        mask_area = mask_height * mask_width

        scaled_polygons = []
        area_percentages = []
        for contour in contours:
            if len(contour) < self.min_polygon_points:
                continue
            # OpenCV contours are (N, 1, 2); flatten to (N, 2) points.
            poly = contour.reshape(-1, 2)
            area = self.calculate_polygon_area(poly)
            area_percentages.append(area / mask_area if mask_area > 0 else 0)
            scaled_polygons.append(np.round(poly * scale).astype(int))

        return scaled_polygons, np.array(area_percentages)

    def process_polygons(self,
                         poly_masks: np.ndarray,
                         image_shape: Tuple[int, int],
                         percentage_threshold: float,
                         overlap_threshold: float,
                         iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """
        Extract polygons from masks, drop small ones, and merge overlapping ones.

        Args:
            poly_masks: Iterable of binary segmentation masks from the model
            image_shape: Tuple of (height, width) of original image
            percentage_threshold: Minimum polygon area as a fraction of the mask area
            overlap_threshold: Minimum overlap ratio for merging polygons
            iou_threshold: Minimum IoU for merging polygons

        Returns:
            Tuple of index-aligned lists:
                - merged_polygons: List of polygon coordinate arrays
                - merged_max_mins: List of bounding boxes as (xmin, ymin, xmax, ymax)
        """
        all_polygons = []
        all_area_percentages = []
        for mask in poly_masks:
            polygons, area_percentages = self.mask_to_polygon_cv2(
                mask=mask,
                original_shape=image_shape
            )
            all_polygons.extend(polygons)
            all_area_percentages.extend(area_percentages)

        filtered_polygons = [
            poly
            for poly, area_pct in zip(all_polygons, all_area_percentages)
            if area_pct > percentage_threshold
        ]
        if not filtered_polygons:
            return [], []

        merged_polygons, _ = self.merge_polygons(
            filtered_polygons,
            iou_threshold,
            overlap_threshold
        )

        # Build polygons and boxes together so the two lists stay index-aligned.
        kept_polygons = []
        merged_max_mins = []
        for poly in merged_polygons:
            if len(poly) == 0:
                continue
            xmin, ymin = np.min(poly, axis=0)
            xmax, ymax = np.max(poly, axis=0)
            kept_polygons.append(poly)
            merged_max_mins.append((xmin, ymin, xmax, ymax))

        return kept_polygons, merged_max_mins

    def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
        """
        Detect and extract ordered text lines and regions from a document image.

        Runs the segmentation model, extracts and merges line and region polygons,
        associates lines with regions, and orders them for reading sequence.

        Args:
            image: Image exposing ``.shape`` as (height, width, ...). This is
                presumably a NumPy array; confirm against
                ``preprocess_resize_torch_transform``.

        Returns:
            List of ordered line dictionaries with region associations, as produced
            by ``order_regions_lines``. Returns None if prediction fails or no lines
            were detected.
        """
        image_shape = (image.shape[0], image.shape[1])

        preprocessed_image = preprocess_resize_torch_transform(
            image,
            max_size=self.max_size
        )

        try:
            detections = self.model.predict(
                preprocessed_image,
                threshold=self.confidence_threshold
            )
        except Exception as e:
            print(f"Error during segmentation prediction: {e}")
            return None

        masks = detections.mask
        if masks is None:
            # Without masks there is nothing to turn into line polygons.
            print('No text lines detected from image.')
            return None
        class_ids = detections.class_id

        merged_line_polygons, merged_line_max_mins = self.process_polygons(
            masks[class_ids == self.class_id_line],
            image_shape,
            self.line_percentage_threshold,
            self.line_overlap_threshold,
            self.line_iou
        )
        # Regions are useless without lines, so bail out before processing them.
        if not merged_line_polygons:
            print('No text lines detected from image.')
            return None

        merged_region_polygons, merged_region_max_mins = self.process_polygons(
            masks[class_ids == self.class_id_region],
            image_shape,
            self.region_percentage_threshold,
            self.region_overlap_threshold,
            self.region_iou
        )

        line_preds = {
            'coords': merged_line_polygons,
            'max_min': merged_line_max_mins
        }

        if merged_region_polygons:
            region_preds = [
                {
                    'coords': region_polygon,
                    'id': str(num),
                    'max_min': region_max_min,
                    'name': 'paragraph',
                    'img_shape': image_shape
                }
                for num, (region_polygon, region_max_min) in enumerate(
                    zip(merged_region_polygons, merged_region_max_mins)
                )
            ]
        else:
            # No regions detected: fall back to one region covering the whole image.
            region_preds = get_default_region(image_shape=image_shape)

        lines_connected_to_regions = get_line_regions(
            lines=line_preds,
            regions=region_preds
        )

        ordered_lines = order_regions_lines(
            lines=lines_connected_to_regions,
            regions=region_preds
        )
        return ordered_lines