MikkoLipsanen committed on
Commit
eee8423
·
verified ·
1 Parent(s): df891fc

Update code to use 202509_onnx_small model

Browse files
Files changed (1) hide show
  1. onnx_text_recognition.py +341 -93
onnx_text_recognition.py CHANGED
@@ -1,115 +1,363 @@
1
- from optimum.onnxruntime import ORTModelForVision2Seq
 
2
  from transformers import TrOCRProcessor
 
3
  import numpy as np
4
  import onnxruntime
5
  import math
 
6
  import cv2
7
  import os
8
 
9
class TextRecognition:
    """Recognise text on cropped line images with a TrOCR processor and an
    ONNX vision-to-sequence model loaded through ``ORTModelForVision2Seq``.

    Args:
        processor_path: Location passed to ``TrOCRProcessor.from_pretrained``.
        model_path: Location passed to ``ORTModelForVision2Seq.from_pretrained``.
        device: Device the pixel values are moved to before generation.
        half_precision: Stored on the instance; not read anywhere in this class.
        line_threshold: Maximum number of line images sent to the model at once.
    """

    def __init__(self,
                 processor_path,
                 model_path,
                 device='cuda:0',
                 half_precision=False,
                 line_threshold=10):
        self.device = device
        self.half_precision = half_precision
        self.line_threshold = line_threshold
        self.processor_path = processor_path
        self.model_path = model_path
        self.processor = self.init_processor()
        self.recognition_model = self.init_recognition_model()

    def init_processor(self):
        """Load the TrOCR processor.

        Returns:
            The processor, or ``None`` when loading fails (the error is printed).
        """
        try:
            return TrOCRProcessor.from_pretrained(self.processor_path, token=True)
        except Exception as e:
            # NOTE(review): the failure is only printed, so ``self.processor``
            # ends up as None and the first prediction fails later instead.
            print('Failed to initialize processor: %s' % e)
            return None

    def init_recognition_model(self):
        """Load the ONNX text recognition model with the CUDA execution provider.

        Returns:
            The model, or ``None`` when loading fails (the error is printed).
        """
        sess_options = onnxruntime.SessionOptions()
        sess_options.intra_op_num_threads = 3
        sess_options.inter_op_num_threads = 3
        try:
            return ORTModelForVision2Seq.from_pretrained(
                self.model_path,
                token=True,
                session_options=sess_options,
                provider="CUDAExecutionProvider",
            )
        except Exception as e:
            # NOTE(review): same print-and-continue behaviour as init_processor.
            print('Failed to load the text recognition model: %s' % e)
            return None

    def crop_line(self, image, polygon, height, width):
        """Crop one text line and place it on a white background.

        Args:
            image: Source image as a numpy array.
            polygon: Sequence of ``[x, y]`` points outlining the line.
            height: Unused; kept for interface compatibility.
            width: Unused; kept for interface compatibility.

        Returns:
            The cropped image in which everything outside the polygon is white.
        """
        points = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)
        x, y, w, h = cv2.boundingRect(points)

        # Clamp the rectangle to the image: a negative start index would make
        # numpy slice from the opposite edge instead of from the border.
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = max(min(x + w, img_w), x0), max(min(y + h, img_h), y0)
        cropped_image = image[y0:y1, x0:x1]

        mask = np.zeros([cropped_image.shape[0], cropped_image.shape[1]], dtype=np.uint8)
        shifted = (points - np.array([[x0, y0]])).astype(np.int32)
        cv2.drawContours(mask, [shifted], -1, (255, 255, 255), -1, cv2.LINE_AA)

        inside = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)
        # White canvas with the polygon area zeroed, so the sum keeps the
        # original pixels inside the polygon and white everywhere else.
        canvas = np.ones_like(cropped_image, np.uint8) * 255
        cv2.bitwise_not(canvas, canvas, mask=mask)
        return canvas + inside

    def crop_lines(self, polygons, image, height, width):
        """Return a list of line images cropped following the polygon coordinates."""
        return [self.crop_line(image, polygon, height, width) for polygon in polygons]

    def get_scores(self, lgscores):
        """Convert log scores to plain scores by exponentiation."""
        return [math.exp(lgscore) for lgscore in lgscores]

    def predict_text(self, cropped_lines):
        """Predict the text content of a batch of cropped line images.

        Returns:
            Tuple ``(scores, generated_text)`` with one entry per input image.
        """
        pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
        generated = self.recognition_model.generate(
            pixel_values.to(self.device),
            max_new_tokens=128,
            return_dict_in_generate=True,
            output_scores=True,
        )
        generated_ids = generated['sequences']
        lgscores = generated['sequences_scores']
        scores = self.get_scores(lgscores.tolist())
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return scores, generated_text

    def get_text_lines(self, cropped_lines):
        """Run text prediction in chunks of at most ``line_threshold`` images.

        An empty input yields two empty lists without calling the model, and a
        non-positive ``line_threshold`` is treated as 1.

        Returns:
            Tuple ``(scores, generated_text)`` covering all input images in order.
        """
        scores, generated_text = [], []
        chunk = max(int(self.line_threshold), 1)
        for start in range(0, len(cropped_lines), chunk):
            sc, gt = self.predict_text(cropped_lines[start:start + chunk])
            scores += sc
            generated_text += gt
        return scores, generated_text

    def get_res_dict(self, polygons, generated_text, height, width, image_name, line_confs, scores):
        """Combine the results for one image into a dictionary."""
        line_dicts = [
            {'polygon': polygons[i], 'text': text, 'conf': line_confs[i], 'text_conf': scores[i]}
            for i, text in enumerate(generated_text)
        ]
        return {'img_name': image_name, 'height': height, 'width': width, 'text_lines': line_dicts}

    def process_lines(self, polygons, image, height, width):
        """Crop every polygon from the image and return the predicted texts.

        The prediction scores are computed but not returned.
        """
        print('starting text generation')
        cropped_lines = self.crop_lines(polygons, image, height, width)
        print('cropped lines')
        _scores, generated_text = self.get_text_lines(cropped_lines)
        return generated_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import VisionEncoderDecoderConfig
2
+ from typing import List, Tuple, Optional
3
  from transformers import TrOCRProcessor
4
+ from pathlib import Path
5
  import numpy as np
6
  import onnxruntime
7
  import math
8
+ import time
9
  import cv2
10
  import os
11
 
12
class TextRecognition:
    """
    ONNX-based text recognition class using TrOCR for handwritten text recognition.

    Processes text line images through separate ONNX encoder and decoder
    sessions with greedy decoding, in batches of ``batch_size`` lines.

    Args:
        model_path: Path to the model directory containing ONNX models and config
        device: Device identifier (default: 'cuda:0'). 'cuda' or 'cuda:N'
            selects the CUDA execution provider on GPU N with CPU fallback;
            any other value selects the CPU execution provider only.
        batch_size: Number of lines to process in parallel (default: 10)
        img_height: Target height for input images (default: 192)
        img_width: Target width for input images (default: 1024)
        max_length: Maximum number of decoding steps per line (default: 128)

    Raises:
        ValueError: If batch_size is smaller than 1
        FileNotFoundError: If model_path does not exist
    """

    def __init__(self,
                 model_path: str,
                 device: str = 'cuda:0',
                 batch_size: int = 10,
                 img_height: int = 192,
                 img_width: int = 1024,
                 max_length: int = 128):
        # A non-positive batch size would make get_text_lines() either raise an
        # obscure range() error (0) or silently return nothing (negative).
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.model_path = model_path
        self.device = device
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width
        self.max_length = max_length

        # Validate model path
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model path does not exist: {model_path}")

        self.init_processor()
        self.init_recognition_model()

    @staticmethod
    def _select_providers(device: str) -> list:
        """Translate a device string into an ONNX Runtime provider list."""
        name, _, index = str(device).strip().lower().partition(':')
        if name != 'cuda':
            return ['CPUExecutionProvider']
        device_id = int(index) if index.isdigit() else 0
        return [
            ('CUDAExecutionProvider', {'device_id': device_id}),
            'CPUExecutionProvider'
        ]

    def init_processor(self) -> None:
        """
        Initialize the TrOCR processor with custom image dimensions.

        Raises:
            RuntimeError: If processor initialization fails
        """
        try:
            self.processor = TrOCRProcessor.from_pretrained(
                str(self.model_path),
                use_fast=True,
                do_resize=True,
                size={
                    'height': self.img_height,
                    'width': self.img_width
                }
            )
            print(f"✓ Processor loaded with custom image size: {self.img_height}x{self.img_width}")
        except Exception as e:
            raise RuntimeError(f'Failed to initialize processor: {e}') from e

    def init_recognition_model(self) -> None:
        """
        Initialize the ONNX encoder and decoder sessions and load the model config.

        The execution providers are derived from ``self.device``.

        Raises:
            FileNotFoundError: If model files are not found
            RuntimeError: If model loading fails
        """
        encoder_path = os.path.join(self.model_path, "encoder_model.onnx")
        decoder_path = os.path.join(self.model_path, "decoder_model.onnx")

        if not os.path.exists(encoder_path):
            raise FileNotFoundError(f"Encoder model not found: {encoder_path}")
        if not os.path.exists(decoder_path):
            raise FileNotFoundError(f"Decoder model not found: {decoder_path}")

        # Session options for better performance
        sess_options = onnxruntime.SessionOptions()
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 4

        providers = self._select_providers(self.device)

        # Load model config (provides the start / EOS / PAD token ids for decoding)
        self.config = VisionEncoderDecoderConfig.from_pretrained(str(self.model_path))

        try:
            print("Loading encoder...")
            self.encoder = onnxruntime.InferenceSession(
                str(encoder_path),
                sess_options=sess_options,
                providers=providers
            )

            print("Loading decoder...")
            self.decoder = onnxruntime.InferenceSession(
                str(decoder_path),
                sess_options=sess_options,
                providers=providers
            )

            # Report which provider is actually being used
            encoder_provider = self.encoder.get_providers()[0]
            decoder_provider = self.decoder.get_providers()[0]
            print(f"✓ Using execution provider: Encoder={encoder_provider}, Decoder={decoder_provider}")

        except Exception as e:
            raise RuntimeError(f'Failed to load recognition models: {e}') from e

    def crop_line(self, image: np.ndarray, polygon: List[List[float]]) -> Optional[np.ndarray]:
        """
        Crop a text line from an image based on polygon coordinates.

        Creates a masked crop where the polygon area contains the original image
        and the background is filled with white pixels. The bounding rectangle
        is clipped to the image, so polygons reaching past an image border are
        cropped at that border.

        Args:
            image: Source image as numpy array
            polygon: List of [x, y] coordinate pairs defining the text line region

        Returns:
            Cropped and masked text line image, or None if the polygon has no
            area inside the image
        """
        # Convert polygon to integer coordinates
        polygon_array = np.array([[int(pt[0]), int(pt[1])] for pt in polygon], dtype=np.int32)

        # Get bounding rectangle
        rect = cv2.boundingRect(polygon_array)
        x, y, w, h = rect

        # Validate rectangle
        if w <= 0 or h <= 0:
            print(f"Warning: Invalid bounding rect dimensions: {w}x{h}")
            return None

        # Clip the rectangle to the image; a negative start index would make
        # numpy slice from the opposite edge instead of from the border.
        img_h, img_w = image.shape[:2]
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, img_w), min(y + h, img_h)

        if x1 <= x0 or y1 <= y0:
            print(f"Warning: Empty cropped image at rect {rect}")
            return None

        cropped_image = image[y0:y1, x0:x1]

        # Create mask for the polygon region
        mask = np.zeros(cropped_image.shape[:2], dtype=np.uint8)

        # Adjust polygon coordinates relative to the cropped region
        polygon_offset = (polygon_array - np.array([[x0, y0]])).astype(np.int32)
        cv2.drawContours(mask, [polygon_offset], -1, (255, 255, 255), -1, cv2.LINE_AA)

        # Extract the polygon region from the cropped image
        masked_region = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)

        # White background with the polygon area zeroed out, so the sum below
        # keeps the original pixels inside the polygon and white outside it
        white_background = np.full_like(cropped_image, 255, dtype=np.uint8)
        cv2.bitwise_not(white_background, white_background, mask=mask)

        return white_background + masked_region

    def crop_lines(self, polygons: List[List[List[float]]], image: np.ndarray) -> List[np.ndarray]:
        """
        Crop multiple text lines from an image.

        Args:
            polygons: List of polygon coordinate lists
            image: Source image

        Returns:
            List of cropped text line images (excluding any failed crops)
        """
        # NOTE(review): failed crops are dropped, so the result can be shorter
        # than `polygons` and positions no longer line up with the input.
        cropped_lines = []
        for i, polygon in enumerate(polygons):
            cropped_line = self.crop_line(image, polygon)
            if cropped_line is None:
                print(f"Warning: Failed to crop line {i}")
                continue
            cropped_lines.append(cropped_line)
        return cropped_lines

    def encode(self, pixel_values: np.ndarray) -> np.ndarray:
        """
        Encode image pixel values into hidden states using the vision encoder.

        Args:
            pixel_values: Preprocessed image array produced by the processor

        Returns:
            First output of the encoder session, used as decoder input

        Raises:
            RuntimeError: If encoding fails
        """
        try:
            return self.encoder.run(None, {"pixel_values": pixel_values})[0]
        except Exception as e:
            raise RuntimeError(f'Failed to encode input: {e}') from e

    def generate(self, encoder_outputs: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Generate text tokens using greedy autoregressive decoding with early stopping.

        Each sequence keeps its EOS token; on later steps a finished sequence
        receives PAD tokens (EOS when the config defines no PAD id) while the
        remaining sequences continue. Decoding stops once every sequence has
        finished or after ``max_length`` steps.

        Args:
            encoder_outputs: Hidden states from the encoder
            batch_size: Number of sequences in the batch

        Returns:
            Generated token IDs including start and end tokens
                Shape: (batch_size, generated_length)

        Raises:
            RuntimeError: If generation fails
        """
        try:
            eos_id = self.config.eos_token_id
            pad_id = self.config.pad_token_id
            if pad_id is None:
                pad_id = eos_id

            # Initialize decoder input with start tokens
            decoder_input_ids = np.full(
                (batch_size, 1),
                self.config.decoder_start_token_id,
                dtype=np.int64
            )

            # Track which sequences have finished
            finished = np.zeros(batch_size, dtype=bool)

            for _ in range(self.max_length):
                # The whole prefix is fed again on every step (no cached state)
                logits = self.decoder.run(
                    None,
                    {
                        "input_ids": decoder_input_ids,
                        "encoder_hidden_states": encoder_outputs
                    }
                )[0]

                # Greedy choice: most likely next token for each sequence
                next_tokens = np.argmax(logits[:, -1, :], axis=-1)

                # Pad only sequences that finished on an EARLIER step, so the
                # EOS token emitted on this step stays in the output
                next_tokens[finished] = pad_id
                finished = finished | (next_tokens == eos_id)

                # Append new tokens to the sequence
                decoder_input_ids = np.concatenate(
                    [decoder_input_ids, next_tokens.reshape(-1, 1)], axis=1
                )

                # Stop when all sequences have finished
                if np.all(finished):
                    break

            return decoder_input_ids

        except Exception as e:
            raise RuntimeError(f'Failed to generate output ids: {e}') from e

    def predict_text(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Predict text content from cropped line images.

        Args:
            cropped_lines: List of cropped text line images

        Returns:
            List of predicted text strings

        Raises:
            RuntimeError: If prediction fails
        """
        try:
            # Use 'pt' (PyTorch) then convert to numpy, as 'np' is not supported by fast processors
            pixel_values = self.processor(cropped_lines, return_tensors="pt").pixel_values
            pixel_values = pixel_values.numpy()
            batch_size = pixel_values.shape[0]

            # Encode images to hidden states
            encoder_hidden_states = self.encode(pixel_values)

            # Generate token sequences
            generated_ids = self.generate(encoder_hidden_states, batch_size)

            # Decode tokens to text; special tokens (start/EOS/PAD) are dropped
            return self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )

        except Exception as e:
            raise RuntimeError(f'Failed to predict text: {e}') from e

    def get_text_lines(self, cropped_lines: List[np.ndarray]) -> List[str]:
        """
        Process text lines in batches of ``batch_size``.

        Args:
            cropped_lines: List of all cropped line images

        Returns:
            List of predicted text strings for all lines, in input order
        """
        generated_text = []
        for start in range(0, len(cropped_lines), self.batch_size):
            batch = cropped_lines[start:start + self.batch_size]
            generated_text.extend(self.predict_text(batch))
        return generated_text

    def process_lines(self,
                      polygons: List[List[List[float]]],
                      image: np.ndarray) -> List[str]:
        """
        Complete pipeline: crop text lines and predict their content.

        Args:
            polygons: List of polygon coordinate lists defining text line regions
            image: Source document image

        Returns:
            List of predicted text strings for each valid line
        """
        # Crop line images from the document
        cropped_lines = self.crop_lines(polygons, image)

        if not cropped_lines:
            print("Warning: No valid cropped lines to process")
            return []

        # Get text predictions for all lines
        return self.get_text_lines(cropped_lines)