File size: 19,449 Bytes
c99fa8d
f8a998a
 
c99fa8d
 
f8a998a
c99fa8d
b32227d
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
f8a998a
 
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
 
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
f8a998a
c99fa8d
 
 
f8a998a
c99fa8d
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
f8a998a
 
 
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
 
c99fa8d
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
 
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
 
 
 
 
 
 
 
 
 
 
 
 
 
f8a998a
c99fa8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
from typing import List, Tuple, Optional, Dict, Any
from shapely.validation import make_valid
from shapely.geometry import Polygon
from rfdetr import RFDETRSegPreview
from collections import defaultdict
import numpy as np
import cv2
import os

from image_processing import (
    load_with_torchvision, 
    preprocess_resize_torch_transform, 
    upscale_bbox, 
    upscale_mask_opencv, 
    crop_line
)

from utils import get_default_region, get_line_regions, order_regions_lines

class SegmentImage:
    """
    Document image segmentation for detecting text regions and lines.
    
    Uses an RFDETR segmentation model to detect and extract text regions and lines
    from document images. Includes polygon merging, validation, and ordering.
    
    Args:
        model_path: Path to the RFDETR segmentation model weights
        max_size: Maximum dimension (height or width) for image preprocessing (default: 768)
        confidence_threshold: Minimum confidence score for detections (default: 0.15, range: 0-1)
        line_percentage_threshold: Minimum polygon area as fraction of image area for lines
                                   (default: 7e-05, i.e., 0.007% of image)
        region_percentage_threshold: Minimum polygon area as fraction of image area for regions
                                     (default: 7e-05, i.e., 0.007% of image)
        line_iou: IoU threshold for merging overlapping line polygons (default: 0.3, range: 0-1)
        region_iou: IoU threshold for merging overlapping region polygons (default: 0.3, range: 0-1)
        line_overlap_threshold: Area overlap ratio threshold for merging lines (default: 0.5, range: 0-1)
        region_overlap_threshold: Area overlap ratio threshold for merging regions (default: 0.5, range: 0-1)
        class_id_region: Class ID constant for identifying regions in segmentation model output
        class_id_line: Class ID constant for identifying lines in segmentation model output
        min_polygon_points: Minimum number of points to form a valid polygon
    """
    def __init__(self, 
                model_path: str,
                max_size: int = 768, 
                confidence_threshold: float = 0.15, 
                line_percentage_threshold: float = 7e-05, 
                region_percentage_threshold: float = 7e-05, 
                line_iou: float = 0.3, 
                region_iou: float = 0.3, 
                line_overlap_threshold: float = 0.5, 
                region_overlap_threshold: float = 0.5,
                class_id_region: int = 1,
                class_id_line: int = 2,
                min_polygon_points: int = 3):

        self.model_path = model_path 
        self.max_size = max_size 
        self.confidence_threshold = confidence_threshold 
        self.line_percentage_threshold = line_percentage_threshold
        self.region_percentage_threshold = region_percentage_threshold
        self.line_iou = line_iou  
        self.region_iou = region_iou 
        self.line_overlap_threshold = line_overlap_threshold
        self.region_overlap_threshold = region_overlap_threshold
        self.class_id_region = class_id_region
        self.class_id_line = class_id_line
        self.min_polygon_points = min_polygon_points
    
        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {self.model_path}")
        
        self.init_model()

    def init_model(self) -> None:
        """
        Load and optimize an RFDETR segmentation model for inference.

        Raises:
            Exception: If model initialization fails
        """
        try:
            self.model = RFDETRSegPreview(pretrain_weights=self.model_path)
            self.model.optimize_for_inference()
            print(f"✓ Segmentation model initialized successfully")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize segmentation model: {e}')
        
    def validate_polygon(self, polygon: np.ndarray) -> Optional[Polygon]:
        """
        Test and correct the validity of a polygon using Shapely.
        
        Converts numpy array to Shapely Polygon, validates it, and attempts
        to fix invalid geometries using make_valid().

        Args:
            polygon: Array of polygon coordinates with shape (N, 2)

        Returns:
            Valid Shapely Polygon object, or None if polygon has fewer than 3 points
        """
        if len(polygon) > 2:
            try:
                shapely_polygon = Polygon(polygon)
                if not shapely_polygon.is_valid:
                    shapely_polygon = make_valid(shapely_polygon)
                return shapely_polygon
            except Exception as e:
                print(f"Warning: Failed to validate polygon: {e}")
                return None
        else:
            return None

    def merge_polygons(self, 
                      polygons: List[np.ndarray], 
                      polygon_iou: float, 
                      overlap_threshold: float) -> Tuple[List[np.ndarray], List[int]]:
        """
        Merge overlapping polygons using connected components (union-find algorithm).
        
        Uses IoU (Intersection over Union) and area overlap ratio to determine which
        polygons should be merged. Implements union-find to group connected components
        of overlapping polygons, then merges each component into a single polygon.
        
        Args:
            polygons: List of polygon coordinate arrays, each with shape (N, 2)
            polygon_iou: IoU threshold for merging (0-1)
            overlap_threshold: Minimum area overlap ratio for merging (0-1)
            
        Returns:
            Tuple of:
                - merged_polygons: List of merged polygon coordinate arrays
                - polygon_mapping: List mapping each input polygon index to its output
                                  polygon index (-1 if invalid/skipped)
        """
        n = len(polygons)
        if n == 0:
            return [], []
        
        # Validate all polygons 
        validated = [self.validate_polygon(p) for p in polygons]
        
        # Build adjacency graph of overlapping polygons
        parent = list(range(n))  
        
        def find(x: int) -> int:
            """Find root of element x with path compression."""
            if parent[x] != x:
                parent[x] = find(parent[x])
            return parent[x]
        
        def union(x: int, y: int) -> None:
            """Union two sets containing x and y."""
            px, py = find(x), find(y)
            if px != py:
                parent[px] = py
        
        # Build adjacency graph by checking all pairs for overlap
        for i in range(n):
            poly1 = validated[i]
            if not poly1:
                continue
                
            for j in range(i + 1, n):
                poly2 = validated[j]
                if not poly2 or not poly1.intersects(poly2):
                    continue

                # Calculate intersection and union for IoU
                intersection = poly1.intersection(poly2)
                union_geom = poly1.union(poly2)
                iou = intersection.area / union_geom.area if union_geom.area > 0 else 0
                
                # Check merge criteria
                should_merge = iou > polygon_iou

                # If IoU threshold not met, check area overlap ratio
                if not should_merge and overlap_threshold > 0:
                    smaller_area = min(poly1.area, poly2.area)
                    overlap_ratio = intersection.area / smaller_area if smaller_area > 0 else 0
                    should_merge = overlap_ratio > overlap_threshold

                # Merge polygons by updating union-find structure
                if should_merge:
                    union(i, j)
        
        # Group polygons by their connected component
        components = defaultdict(list)
        for i in range(n):
            if validated[i]: 
                root = find(i)
                components[root].append(i)
        
        # Merge each connected component
        merged_polygons = []
        polygon_mapping = [-1] * n  # -1 indicates invalid/unmapped polygon
        
        for root, indices in components.items():
            output_idx = len(merged_polygons)

            if len(indices) == 1:
                # Single polygon, no merging needed
                idx = indices[0]
                merged_polygons.append(polygons[idx])
                polygon_mapping[idx] = output_idx
            
            else:
                # Merge all polygons in this component using Shapely union
                merged = validated[indices[0]]
                for idx in indices[1:]:
                    merged = merged.union(validated[idx])

                # Extract polygon coordinates from merged geometry
                if merged.geom_type == 'Polygon':
                    # Single polygon result
                    merged_polygons.append(
                        np.array(merged.exterior.coords).astype(np.int32)
                    )
                    for idx in indices:
                        polygon_mapping[idx] = output_idx

                elif merged.geom_type in ['MultiPolygon', 'GeometryCollection']:
                    # Multiple polygons resulted from merge (e.g., touching at single point)
                    for geom in merged.geoms:
                        if geom.geom_type == 'Polygon':
                            merged_polygons.append(
                                np.array(geom.exterior.coords).astype(np.int32)
                            )
                    # Map all source polygons to first output polygon
                    for idx in indices:
                        polygon_mapping[idx] = output_idx
        
        return merged_polygons, polygon_mapping

    def calculate_polygon_area(self, vertices: np.ndarray) -> float:
        """
        Calculate polygon area using the Shoelace formula (surveyor's formula).
        
        Computes area using coordinate cross products. Works for simple polygons
        (non-self-intersecting) regardless of vertex ordering.

        Args:
            vertices: Array of polygon coordinates with shape (N, 2)
            
        Returns:
            Area of the polygon in square pixels
        """
        x = vertices[:, 0]
        y = vertices[:, 1]
        # Shoelace formula implementation using array operations
        area = 0.5 * np.abs(np.sum(x[:-1] * y[1:]) - np.sum(y[:-1] * x[1:]) + x[-1] * y[0] - y[-1] * x[0])
        return area

    def mask_to_polygon_cv2(
        self,
        mask: np.ndarray,
        original_shape: Tuple[int, int],
    ) -> Tuple[List[np.ndarray], np.ndarray]:
        """
        Convert a segmentation mask to polygon coordinates using OpenCV contours.

        Extracts the external contours of the mask, drops contours with fewer
        than ``self.min_polygon_points`` points, measures each remaining
        contour's area in mask coordinates, and scales its points to the
        original image dimensions.

        Args:
            mask: Mask as numpy array. Boolean masks are converted to 0/255
                uint8; any other dtype is cast to uint8 directly.
            original_shape: Tuple of (height, width) of original image

        Returns:
            Tuple of:
                - scaled_polygons: List of integer coordinate arrays, each of
                  shape (N, 2), scaled to original size
                - area_percentages: Array of polygon areas as fraction of the
                  mask area (mask height * mask width)
        """
        # findContours needs an 8-bit single-channel image
        if mask.dtype == bool:
            mask_uint8 = mask.astype(np.uint8) * 255
        else:
            mask_uint8 = mask.astype(np.uint8)

        # Find external contours (only outer boundaries)
        contours, _ = cv2.findContours(
            mask_uint8,
            cv2.RETR_EXTERNAL,
            cv2.CHAIN_APPROX_SIMPLE
        )

        # Scaling factors from mask to original image, as an (x, y) pair
        orig_height, orig_width = original_shape
        mask_height, mask_width = mask.shape[:2]
        scale = np.array([orig_width / mask_width, orig_height / mask_height])
        mask_area = mask_height * mask_width

        scaled_polygons = []
        area_percentages = []

        for contour in contours:
            # Filter out degenerate contours
            if len(contour) < self.min_polygon_points:
                continue

            # reshape (not squeeze) keeps a single-point contour two-dimensional,
            # so every polygon is consistently (N, 2)
            poly = contour.reshape(-1, 2)

            # Area is measured on mask coordinates (before scaling)
            area = self.calculate_polygon_area(poly)
            area_percentages.append(area / mask_area if mask_area > 0 else 0)

            # Scale polygon coordinates to original image size
            scaled_polygons.append(np.round(poly * scale).astype(int))

        return scaled_polygons, np.array(area_percentages)


    def process_polygons(self, 
                        poly_masks: np.ndarray, 
                        image_shape: Tuple[int, int], 
                        percentage_threshold: float, 
                        overlap_threshold: float, 
                        iou_threshold: float) -> Tuple[List[np.ndarray], List[Tuple[int, int, int, int]]]:
        """
        Extract polygons from segmentation masks, filter by area, and merge overlapping ones.
        
        Converts masks to polygons, filters out small detections based on area percentage,
        and merges overlapping polygons based on IoU and overlap criteria.

        Args:
            poly_masks: Array of binary segmentation masks from model
            image_shape: Tuple of (height, width) of original image
            percentage_threshold: Minimum polygon area as fraction of image
            overlap_threshold: Minimum overlap ratio for merging polygons
            iou_threshold: Minimum IoU for merging polygons

        Returns:
            Tuple of:
                - merged_polygons: List of polygon coordinate arrays
                - merged_max_mins: List of bounding boxes as (xmin, ymin, xmax, ymax) tuples
        """
        all_polygons = []
        all_area_percentages = []

        # Extract polygons from all masks
        for mask in poly_masks:
            polygons, area_percentages = self.mask_to_polygon_cv2(
                mask=mask, 
                original_shape=image_shape
            )
            all_polygons.extend(polygons)
            all_area_percentages.extend(area_percentages)
        
        all_area_percentages = np.array(all_area_percentages)

        # Filter polygons by minimum area threshold
        if len(all_area_percentages) == 0:
            return [], []
        
        valid_indices = np.where(all_area_percentages > percentage_threshold)[0]
        filtered_polygons = [all_polygons[idx] for idx in valid_indices]
        
        if not filtered_polygons:
            return [], []

        # Merge overlapping polygons
        merged_polygons, _ = self.merge_polygons(
            filtered_polygons, 
            iou_threshold, 
            overlap_threshold
        )

        # Calculate bounding boxes for merged polygons
        merged_max_mins = []
        for poly in merged_polygons:
            if len(poly) > 0:
                xmax, ymax = np.max(poly, axis=0)
                xmin, ymin = np.min(poly, axis=0)
                merged_max_mins.append((xmin, ymin, xmax, ymax))

        return merged_polygons, merged_max_mins

    def get_segmentation(self, image) -> Optional[List[Dict[str, Any]]]:
        """
        Detect and extract ordered text lines and regions from a document image.

        Runs the segmentation model on the image, extracts line and region polygons,
        merges overlapping detections, associates lines with regions, and orders them
        for reading sequence.

        Args:
            image: Image object exposing ``.shape``, with height and width as
                its first two dimensions (e.g. an H x W x C array). It is
                passed unchanged to preprocess_resize_torch_transform().
                NOTE(review): a PIL Image has no ``.shape`` and would fail
                here; confirm the exact type expected by the preprocessing helper.

        Returns:
            The ordered lines produced by order_regions_lines(), or None if
            prediction failed or no lines were detected.
        """
        image_shape = (image.shape[0], image.shape[1])

        # Preprocess image (resize for model input)
        preprocessed_image = preprocess_resize_torch_transform(
            image,
            max_size=self.max_size
        )

        # Run segmentation model
        try:
            detections = self.model.predict(
                preprocessed_image,
                threshold=self.confidence_threshold
            )
        except Exception as e:
            print(f"Error during segmentation prediction: {e}")
            return None

        # Without masks there is nothing to extract. Guarding here avoids a
        # TypeError from indexing None below.
        # NOTE(review): assumes predict() can return detections whose mask is
        # None when nothing is found; confirm against the detections API.
        masks = detections.mask
        if masks is None:
            print('No text lines detected from image.')
            return None

        # Separate line and region masks by class ID
        class_ids = detections.class_id
        line_mask = masks[class_ids == self.class_id_line]
        region_mask = masks[class_ids == self.class_id_region]

        # Process line polygons
        merged_line_polygons, merged_line_max_mins = self.process_polygons(
            line_mask,
            image_shape,
            self.line_percentage_threshold,
            self.line_overlap_threshold,
            self.line_iou
        )

        # Process region polygons
        merged_region_polygons, merged_region_max_mins = self.process_polygons(
            region_mask,
            image_shape,
            self.region_percentage_threshold,
            self.region_overlap_threshold,
            self.region_iou
        )

        # If no lines detected, return None
        if not merged_line_polygons:
            print('No text lines detected from image.')
            return None

        # Prepare line predictions dictionary
        line_preds = {
            'coords': merged_line_polygons,
            'max_min': merged_line_max_mins
        }

        # Prepare region predictions (or use default if none detected)
        if merged_region_polygons:
            region_preds = [
                {
                    'coords': region_polygon,
                    'id': str(num),
                    'max_min': region_max_min,
                    'name': 'paragraph',
                    'img_shape': image_shape
                }
                for num, (region_polygon, region_max_min) in enumerate(
                    zip(merged_region_polygons, merged_region_max_mins)
                )
            ]
        else:
            # No regions detected: fall back to the default region for this image shape
            region_preds = get_default_region(image_shape=image_shape)

        # Associate lines with their containing regions
        lines_connected_to_regions = get_line_regions(
            lines=line_preds,
            regions=region_preds
        )

        # Order lines within regions for proper reading sequence
        ordered_lines = order_regions_lines(
            lines=lines_connected_to_regions,
            regions=region_preds
        )

        return ordered_lines