import logging
import string
from collections import defaultdict
from typing import Any, List, Union

import cv2
import numpy as np
import torch
from doctr.io.elements import Document
from doctr.models import parseq
from doctr.models._utils import get_language
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.detection.zoo import detection_predictor
from doctr.models.predictor.base import _OCRPredictor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.recognition.zoo import recognition_predictor
from doctr.utils.geometry import detach_scores
from PIL import Image, ImageDraw, ImageFont
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from torch import nn
# Configure logging
logging.basicConfig(level=logging.INFO)

confidence_threshold = 0.75
reco_arch = "printed_v19.pt"
det_arch = "fast_base"

# Recognition vocabulary (must match the vocab the custom checkpoint was trained with)
afterword_symbols = "!?.,:;"
numbers = "0123456789"
other_symbols = string.punctuation + "«»…£€¥¢฿₸₽№°—"
space_symbol = " "
kazakh_letters = "АБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯЁабвгдежзийклмнопрстуфхцчшщъыьэюяёӘҒҚҢӨҰҮІҺәғқңөұүіһ"
english_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
all_letters = kazakh_letters + english_letters
all_symbols = numbers + other_symbols + space_symbol + all_letters

def get_ocr_predictor(
    det_arch: str = det_arch,
    reco_arch: str = reco_arch,
    pretrained: bool = True,
    pretrained_backbone: bool = True,
    assume_straight_pages: bool = False,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    det_bs: int = 2,
    reco_bs: int = 128,
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    bin_thresh: float = 0.3,
    box_thresh: float = 0.3,
):
    # Pick the best available device: MPS (Apple Silicon), then CUDA, then CPU
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    logging.info(f"Using device: {device}")
    device = torch.device(device)

    # Initialize the recognition model with the custom vocabulary and weights
    logging.info(f"Initializing predictor with device: {device}")
    reco_model = parseq(pretrained=False, pretrained_backbone=False, vocab=all_symbols)
    reco_model.to(device)
    reco_params = torch.load(f"./custom/{reco_arch}", map_location=device)
    reco_model.load_state_dict(reco_params)
    # Detection
    det_predictor = detection_predictor(
        det_arch,
        pretrained=pretrained,
        pretrained_backbone=pretrained_backbone,
        batch_size=det_bs,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
    )
    # Recognition
    reco_predictor = recognition_predictor(
        reco_model,
        pretrained=pretrained,
        pretrained_backbone=pretrained_backbone,
        batch_size=reco_bs,
    )
    predictor = OCRPredictor(
        det_predictor,
        reco_predictor,
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
    )
    predictor.det_predictor.model.postprocessor.bin_thresh = bin_thresh
    predictor.det_predictor.model.postprocessor.box_thresh = box_thresh
    predictor.add_hook(CustomHook())
    return predictor
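
# Example usage (a minimal sketch; "sample.jpg" is a hypothetical path, and the
# custom checkpoint is assumed to exist under ./custom/):
#
#     predictor = get_ocr_predictor()
#     page = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)
#     result = predictor([page])
#     export = result.export()  # nested dict: pages -> blocks -> lines -> words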

class OCRPredictor(nn.Module, _OCRPredictor):
    """Implements an object able to localize and identify text elements in a set of documents

    Args:
    ----
        det_predictor: detection module
        reco_predictor: recognition module
        assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
            without rotated textual elements.
        straighten_pages: if True, estimates the general page orientation based on the median line orientation,
            then rotates the page before passing it to the deep learning modules. The final predictions will be
            remapped accordingly. Doing so improves performance for documents with page-uniform rotations.
        detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
            page. Doing so will slightly increase the overall latency.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly increase the overall latency.
        **kwargs: keyword args of `DocumentBuilder`
    """

    def __init__(
        self,
        det_predictor: DetectionPredictor,
        reco_predictor: RecognitionPredictor,
        assume_straight_pages: bool = True,
        straighten_pages: bool = False,
        preserve_aspect_ratio: bool = True,
        symmetric_pad: bool = True,
        detect_orientation: bool = False,
        detect_language: bool = False,
        **kwargs: Any,
    ) -> None:
        nn.Module.__init__(self)
        self.det_predictor = det_predictor.eval()  # type: ignore[attr-defined]
        self.reco_predictor = reco_predictor.eval()  # type: ignore[attr-defined]
        _OCRPredictor.__init__(
            self,
            assume_straight_pages,
            straighten_pages,
            preserve_aspect_ratio,
            symmetric_pad,
            detect_orientation,
            **kwargs,
        )
        self.detect_orientation = detect_orientation
        self.detect_language = detect_language

    def forward(
        self,
        pages: List[Union[np.ndarray, torch.Tensor]],
        **kwargs: Any,
    ) -> Document:
        # Dimension check
        if any(page.ndim != 3 for page in pages):
            raise ValueError(
                "incorrect input shape: all pages are expected to be multi-channel 2D images."
            )
        origin_page_shapes = [
            page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:]
            for page in pages
        ]

        # Localize text elements
        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

        # Detect document rotation and rotate pages
        seg_maps = [
            np.where(
                out_map > getattr(self.det_predictor.model.postprocessor, "bin_thresh"),
                255,
                0,
            ).astype(np.uint8)
            for out_map in out_maps
        ]
        if self.detect_orientation:
            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)  # type: ignore[arg-type]
            orientations = [
                {"value": orientation_page, "confidence": None}
                for orientation_page in origin_pages_orientations
            ]
        else:
            orientations = None
            general_pages_orientations = None
            origin_pages_orientations = None

        if self.straighten_pages:
            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
            # Forward again to get predictions on straight pages
            loc_preds = self.det_predictor(pages, **kwargs)

        assert all(
            len(loc_pred) == 1 for loc_pred in loc_preds
        ), "Detection Model in ocr_predictor should output only one class"
        loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]

        # Detach objectness scores from loc_preds
        loc_preds, objectness_scores = detach_scores(loc_preds)

        # Check whether crop mode should be switched to channels first
        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)

        # Apply hooks to loc_preds if any
        for hook in self.hooks:
            loc_preds = hook(loc_preds)

        # Crop images
        crops, loc_preds = self._prepare_crops(
            pages,  # type: ignore[arg-type]
            loc_preds,
            channels_last=channels_last,
            assume_straight_pages=self.assume_straight_pages,
        )
        # Rectify crop orientation and get crop orientation predictions
        crop_orientations: Any = []
        # Debugging helper: save crops to ./crops (requires `import os`)
        # os.makedirs("./crops", exist_ok=True)
        # for i, crop in enumerate(crops[0]):
        #     Image.fromarray(crop).save(f"./crops/{i}.png")
        # Crop rectification is disabled here (commented out):
        # if not self.assume_straight_pages:
        #     crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
        #     crop_orientations = [
        #         {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
        #     ]

        # Identify character sequences
        word_preds = self.reco_predictor(
            [crop for page_crops in crops for crop in page_crops], **kwargs
        )
        if not crop_orientations:
            crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]

        boxes, text_preds, crop_orientations = self._process_predictions(
            loc_preds, word_preds, crop_orientations
        )
        if self.detect_language:
            languages = [
                get_language(" ".join([item[0] for item in text_pred]))
                for text_pred in text_preds
            ]
            languages_dict = [
                {"value": lang[0], "confidence": lang[1]} for lang in languages
            ]
        else:
            languages_dict = None

        out = self.doc_builder(
            pages,  # type: ignore[arg-type]
            boxes,
            objectness_scores,
            text_preds,
            origin_page_shapes,  # type: ignore[arg-type]
            crop_orientations,
            orientations,
            languages_dict,
        )
        return out

class CustomHook:
    def __call__(self, loc_preds):
        # Manipulate the location predictions here
        # 1. The output structure needs to be the same as the input location predictions
        # 2. Be aware that the coordinates are relative and need to be between 0 and 1
        # Re-order the corner points of every box on every page
        answer = []
        for page_boxes in loc_preds:
            bboxes = [self.order_bbox_points(box) for box in page_boxes]
            answer.append(np.array(bboxes))
        # Return one array per page so pages with differing box counts stay valid
        return answer

    def order_bbox_points(self, points):
        """
        Orders a list of four (x, y) points in the following order:
        top-left, top-right, bottom-right, bottom-left.

        Args:
            points (list of tuples): List of four (x, y) tuples.

        Returns:
            list of tuples: Ordered list of four (x, y) tuples.
        """
        if len(points) != 4:
            raise ValueError(
                "Exactly four points are required to define a quadrilateral."
            )
        # Convert points to a NumPy array for easier manipulation
        pts = np.array(points)
        # Compute the sum (x + y) and difference (y - x) of each point
        sum_pts = pts.sum(axis=1)
        diff_pts = np.diff(pts, axis=1).flatten()
        # Initialize ordered points list
        ordered = [None] * 4
        # Top-left point has the smallest sum
        ordered[0] = tuple(pts[np.argmin(sum_pts)])
        # Bottom-right point has the largest sum
        ordered[2] = tuple(pts[np.argmax(sum_pts)])
        # Top-right point has the smallest difference (y - x)
        ordered[1] = tuple(pts[np.argmin(diff_pts)])
        # Bottom-left point has the largest difference (y - x)
        ordered[3] = tuple(pts[np.argmax(diff_pts)])
        return ordered
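
# A quick sanity check of the corner-ordering logic (not part of the pipeline;
# the input points below are made up for illustration):
#
#     hook = CustomHook()
#     shuffled = [(100, 50), (10, 10), (100, 10), (10, 50)]
#     hook.order_bbox_points(shuffled)
#     # -> [(10, 10), (100, 10), (100, 50), (10, 50)]  (TL, TR, BR, BL)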

def geometry_to_coordinates(geometry, img_width, img_height):
    """Converts a relative doctr geometry to absolute pixel corner coordinates."""
    if len(geometry) == 2:
        # Two-point geometry: ((x0, y0), (x1, y1)) as top-left / bottom-right
        (x0_rel, y0_rel), (x1_rel, y1_rel) = geometry
        x0 = int(x0_rel * img_width)
        y0 = int(y0_rel * img_height)
        x1 = int(x1_rel * img_width)
        y1 = int(y1_rel * img_height)
        # Bounding box with four corners
        all_four = [[x0, y0], [x1, y0], [x1, y1], [x0, y1]]
        return all_four
    else:
        # Polygon geometry: scale each of the four corners to pixels
        all_four = [[int(x * img_width), int(y * img_height)] for x, y in geometry]
        return all_four
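
# For example (values made up for illustration): a two-point geometry
# ((0.1, 0.2), (0.5, 0.4)) on a 1000x500 image expands to four corners:
#
#     geometry_to_coordinates(((0.1, 0.2), (0.5, 0.4)), 1000, 500)
#     # -> [[100, 100], [500, 100], [500, 200], [100, 200]]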

def page_to_coordinates(page_export):
    """Flattens a doctr page export into (corners, word, confidence) tuples."""
    coordinates = []
    img_height, img_width = page_export["dimensions"]
    for block in page_export["blocks"]:
        for line in block["lines"]:
            for word in line["words"]:
                # Skip multi-character words below the confidence threshold;
                # single-character detections are kept regardless of confidence
                if (
                    word["confidence"] < confidence_threshold
                    and len(word["value"].strip()) > 1
                ):
                    logging.warning(
                        f"Skipping word with low confidence: {word['value']} confidence {word['confidence']}"
                    )
                    continue
                all_four = geometry_to_coordinates(
                    word["geometry"], img_width, img_height
                )
                coordinates.append((all_four, word["value"], word["confidence"]))
    return (coordinates, img_width, img_height)
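
# The flattened structure looks like this (illustrative values only):
#
#     coords, w, h = page_to_coordinates(result.export()["pages"][0])
#     # coords[0] == ([[12, 34], [80, 34], [80, 56], [12, 56]], "Алматы", 0.98)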

def draw_boxes_with_labels(image, coordinates, font_path):
    """Draws word boxes and small text labels on top of the source image.

    Args:
        image: Source image (numpy array, BGR).
        coordinates: List of (corners, word, score) tuples, e.g. from page_to_coordinates.
        font_path: Path to a TrueType font file.

    Returns:
        The image with boxes and labels drawn.
    """
    # Convert the image to PIL format
    img_with_boxes = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img_pil = Image.fromarray(img_with_boxes)
    draw = ImageDraw.Draw(img_pil)
    font = ImageFont.truetype(font_path, 10)
    for coords, word, score in coordinates:
        # Draw the word polygon and place the label just above its top-left corner
        coords = [(x, y) for x, y in coords]
        text_x, text_y = (
            min(coords, key=lambda x: x[0])[0],
            min(coords, key=lambda x: x[1])[1],
        )
        draw.polygon(coords, outline=(0, 255, 0, 125), width=1)
        draw.text((text_x, max(text_y - 10, 0)), word, font=font, fill=(255, 0, 0))
    # Convert the image back to OpenCV format and return it
    img_with_boxes = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
    return img_with_boxes

def generate_line_points(bbox, num_points=10):
    """
    Generates multiple points along the line connecting the left and right centers of a bounding box.

    Parameters:
    - bbox: List of four points [[x0, y0], [x1, y1], [x2, y2], [x3, y3]]
      in the order: TopLeft, TopRight, BottomRight, BottomLeft.
    - num_points: Number of points to generate along the line.

    Returns:
    - List of (x, y) tuples.
    """
    # Left center (midpoint of TopLeft and BottomLeft)
    left_center_x = (bbox[0][0] + bbox[3][0]) / 2
    left_center_y = (bbox[0][1] + bbox[3][1]) / 2
    # Right center (midpoint of TopRight and BottomRight)
    right_center_x = (bbox[1][0] + bbox[2][0]) / 2
    right_center_y = (bbox[1][1] + bbox[2][1]) / 2
    # Generate linearly spaced points between the left and right centers
    x_values = np.linspace(left_center_x, right_center_x, num_points)
    y_values = np.linspace(left_center_y, right_center_y, num_points)
    points = list(zip(x_values, y_values))
    return points
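
# For instance (illustrative box): a 10x4 axis-aligned box yields points along
# its horizontal midline:
#
#     generate_line_points([[0, 0], [10, 0], [10, 4], [0, 4]], num_points=3)
#     # -> [(0.0, 2.0), (5.0, 2.0), (10.0, 2.0)]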

def ocr_to_txt(coordinates):
    """
    Converts OCR output to structured text with lines, by clustering multiple points
    sampled along each word's midline. Inserts empty lines when there is significant
    vertical spacing between lines.

    Parameters:
    - coordinates: List of tuples containing bounding box coordinates, word value, and score.
      Each tuple is ([[x0, y0], [x1, y1], [x2, y2], [x3, y3]], word, score).

    Returns:
    - The assembled text as a single string.
    """
    # Step 1: Compute multiple points along the midline of each word
    all_points = []
    words = []
    scaler = StandardScaler()
    points_per_word = 25  # Number of points to generate per word
    for bbox, word, score in coordinates:
        points = generate_line_points(bbox, num_points=points_per_word)
        all_points.extend(points)
        words.append(
            {
                "bbox": bbox,
                "word": word,
                "score": score,
                "points": points,  # Store the multiple points
            }
        )

    # Step 2: Scale the points, then compress the x-axis so clustering is driven
    # mainly by vertical position (words on one line share similar y)
    scaled_points = scaler.fit_transform(all_points)
    scaled_points = [(c[0] / 5, c[1]) for c in scaled_points]
    scaled_points = np.array(scaled_points)

    # Step 3: Cluster points into text lines using DBSCAN.
    # eps is the maximum distance between two samples for them to be considered
    # in the same neighborhood; min_samples=2 lets even short words form a cluster.
    # Both may need tuning for a specific OCR output.
    db = DBSCAN(min_samples=2, eps=0.05).fit(scaled_points)
    labels = db.labels_
    # Map each point to its cluster label
    point_labels = labels.tolist()
    # Step 4: Assign each word to the cluster that claims most of its points
    label_to_words = defaultdict(list)
    current_point = 0  # Tracks which slice of point_labels belongs to which word
    for word in words:
        word_labels = point_labels[current_point : current_point + points_per_word]
        current_point += points_per_word
        # Count the frequency of each label among the word's points
        label_counts = defaultdict(int)
        for lbl in word_labels:
            label_counts[lbl] += 1
        # Assign the word to its most frequent label, preferring real clusters
        # over the DBSCAN noise label (-1)
        if label_counts:
            filtered_labels = {k: v for k, v in label_counts.items() if k != -1}
            if filtered_labels:
                assigned_label = max(filtered_labels, key=filtered_labels.get)
            else:
                assigned_label = -1  # All of the word's points were noise
            label_to_words[assigned_label].append(word)
    # Drop the noise cluster if present
    if -1 in label_to_words:
        logging.warning(
            f"{len(label_to_words[-1])} words assigned to noise cluster and will be ignored."
        )
        del label_to_words[-1]
    # Step 5: Sort words within each line by their leftmost x-coordinate
    sorted_lines = []
    line_heights = []  # Median height of each line, for spacing checks
    line_y_bounds = []  # (median min y, median max y) for each line
    for label, line_words in label_to_words.items():
        line_words_sorted = sorted(
            line_words, key=lambda w: min(point[0] for point in w["points"])
        )
        sorted_lines.append(line_words_sorted)
        # Collect each word's vertical extent
        y_values = []
        for word in line_words_sorted:
            y_coords = [point[1] for point in word["bbox"]]
            y_values.append([min(y_coords), max(y_coords)])
        y_values = np.array(y_values)
        # Take the median top and bottom y across the line's words
        line_min_y_median = np.median(y_values[:, 0])
        line_max_y_median = np.median(y_values[:, 1])
        line_heights.append(line_max_y_median - line_min_y_median)
        line_y_bounds.append((line_min_y_median, line_max_y_median))

    if not sorted_lines:
        return ""  # Nothing survived clustering

    # Step 6: Sort lines top to bottom by the median of their words' mean y-coordinates
    sorted_lines, line_heights, line_y_bounds = zip(
        *sorted(
            zip(sorted_lines, line_heights, line_y_bounds),
            key=lambda item: np.median(
                [np.mean([p[1] for p in w["bbox"]]) for w in item[0]]
            ),
        )
    )
    sorted_lines = list(sorted_lines)
    line_heights = list(line_heights)
    line_y_bounds = list(line_y_bounds)
    # Step 7: Assemble the output text, inserting empty lines where the vertical
    # gap between consecutive lines is large
    output_text = ""
    previous_line_median_y = None  # Median center y of the previous line
    previous_line_median_height = None
    for idx, line in enumerate(sorted_lines):
        # Compute the current line's vertical center and height
        current_line_min_y_median = line_y_bounds[idx][0]
        current_line_max_y_median = line_y_bounds[idx][1]
        current_line_median_height = line_heights[idx]
        current_line_median_y = (
            current_line_min_y_median + current_line_max_y_median
        ) / 2
        if previous_line_median_y is not None:
            # Compute the vertical distance between line centers
            vertical_distance = current_line_median_y - previous_line_median_y
            median_height = (
                current_line_median_height + previous_line_median_height
            ) / 2
            # If the gap exceeds twice the average line height, insert an empty line
            if vertical_distance > median_height * 2:
                output_text += "\n"
        # Append the current line's text
        line_text = " ".join([w["word"] for w in line])
        output_text += line_text + "\n"
        # Remember this line's center and height for the next iteration
        previous_line_median_y = current_line_median_y
        previous_line_median_height = current_line_median_height
    return output_text
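

# End-to-end sketch tying the pieces together. Runs only when the module is
# executed directly; "sample.jpg" and "DejaVuSans.ttf" are hypothetical
# stand-ins for a real input image and font path.
if __name__ == "__main__":
    predictor = get_ocr_predictor()
    image = cv2.imread("sample.jpg")  # BGR, as loaded by OpenCV
    page = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    result = predictor([page])
    page_export = result.export()["pages"][0]

    # Flatten the export, reconstruct the plain text, and visualize the boxes
    coordinates, _, _ = page_to_coordinates(page_export)
    print(ocr_to_txt(coordinates))
    annotated = draw_boxes_with_labels(image, coordinates, "DejaVuSans.ttf")
    cv2.imwrite("annotated.jpg", annotated)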