from huggingface_hub import login, snapshot_download, hf_hub_download
from typing import Optional, Tuple, Dict, Any
from transformers import TrOCRProcessor
from datetime import datetime
from pathlib import Path
import gradio as gr
import numpy as np
import onnxruntime
import tempfile
import logging
import torch
import time
import json
import sys
import os

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()  # Explicit stdout handler for HF Spaces
    ]
)
logger = logging.getLogger(__name__)
# Log startup info for debugging in HF Spaces
logger.info("=" * 50)
logger.info("HTR Application Starting")
logger.info(f"Python version: {sys.version}")
logger.info(f"Running on Hugging Face Spaces: {os.getenv('SPACE_ID', 'Local')}")
logger.info("=" * 50)
# Configuration from environment variables
class Config:
    """Application configuration from environment variables."""
    HF_TOKEN = os.getenv("HF_TOKEN")
    SEGMENTATION_MAX_SIZE = 768
    RECOGNITION_BATCH_SIZE = 10
    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.15
    SEGMENTATION_LINE_PERCENTAGE_THRESHOLD = 7e-05
    SEGMENTATION_REGION_PERCENTAGE_THRESHOLD = 7e-05
    SEGMENTATION_LINE_IOU = 0.3
    SEGMENTATION_REGION_IOU = 0.3
    SEGMENTATION_LINE_OVERLAP_THRESHOLD = 0.5
    SEGMENTATION_REGION_OVERLAP_THRESHOLD = 0.5
    # Allowed upload sources: the Astia service and Gradio's temp directory
    ALLOWED_SOURCES = ("https://astia.narc.fi", "/tmp/gradio")

    # Model paths
    TROCR_MODEL_REPO = "Kansallisarkisto/multicentury-htr-model-small-onnx"
    SEGMENTATION_MODEL_REPO = "Kansallisarkisto/rfdetr_textline_textregion_detection_model"
    SEGMENTATION_MODEL_FILE = "rfdetr_text_seg_model_202510.pth"
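
# Note (deployment assumption, not part of the original config): HF_TOKEN is
# typically provided as a secret in the Space settings; for a local run it can
# be exported before launch, e.g.
#
#   export HF_TOKEN=hf_xxx   # placeholder value
#   python app.py            # assumed entry-point filename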
# Login to HuggingFace if token is available
if Config.HF_TOKEN:
    try:
        login(token=Config.HF_TOKEN, add_to_git_credential=True)
        logger.info("✓ Logged in to HuggingFace")
    except Exception as e:
        logger.warning(f"Failed to login to HuggingFace: {e}")

def download_models() -> Tuple[str, str]:
    """
    Download required models from HuggingFace Hub.

    Returns:
        Tuple of (text_recognition_model_path, segmentation_model_path)

    Raises:
        RuntimeError: If model download fails
    """
    try:
        logger.info("Downloading text recognition model...")
        trocr_path = snapshot_download(repo_id=Config.TROCR_MODEL_REPO)
        logger.info(f"✓ Text recognition model downloaded to {trocr_path}")

        logger.info("Downloading segmentation model...")
        seg_path = hf_hub_download(
            repo_id=Config.SEGMENTATION_MODEL_REPO,
            filename=Config.SEGMENTATION_MODEL_FILE
        )
        logger.info(f"✓ Segmentation model downloaded to {seg_path}")
        return trocr_path, seg_path
    except Exception as e:
        logger.error(f"Failed to download models: {e}")
        raise RuntimeError(f"Model download failed: {e}")

# Download models
TROCR_MODEL_PATH, SEGMENTATION_MODEL_PATH = download_models()

# Log CUDA availability
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

class HTRPipeline:
    """
    Handwritten Text Recognition pipeline combining segmentation and recognition.

    This class manages the initialization and execution of document segmentation
    and text recognition models.
    """

    def __init__(self,
                 segmentation_model_path: str,
                 recognition_model_path: str,
                 segmentation_max_size: int = 768,
                 recognition_batch_size: int = 10,
                 segmentation_confidence_threshold: float = 0.15,
                 segmentation_line_percentage_threshold: float = 7e-05,
                 segmentation_region_percentage_threshold: float = 7e-05,
                 segmentation_line_iou: float = 0.3,
                 segmentation_region_iou: float = 0.3,
                 segmentation_line_overlap_threshold: float = 0.5,
                 segmentation_region_overlap_threshold: float = 0.5
                 ):
        """
        Initialize HTR pipeline with segmentation and recognition models.

        Args:
            segmentation_model_path: Path to segmentation model weights
            recognition_model_path: Path to recognition model directory
            segmentation_max_size: Maximum image dimension for segmentation
            recognition_batch_size: Batch size for text recognition
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions
        """
        self.segmenter = self._init_segmenter(
            segmentation_model_path,
            segmentation_max_size,
            segmentation_confidence_threshold,
            segmentation_line_percentage_threshold,
            segmentation_region_percentage_threshold,
            segmentation_line_iou,
            segmentation_region_iou,
            segmentation_line_overlap_threshold,
            segmentation_region_overlap_threshold
        )
        self.recognizer = self._init_recognizer(recognition_model_path, recognition_batch_size)
        self.plotter = PlotHTR()
        if self.segmenter is None or self.recognizer is None:
            raise RuntimeError("Failed to initialize HTR pipeline components")

    def _init_segmenter(self,
                        model_path: str,
                        max_size: int,
                        segmentation_confidence_threshold: float,
                        segmentation_line_percentage_threshold: float,
                        segmentation_region_percentage_threshold: float,
                        segmentation_line_iou: float,
                        segmentation_region_iou: float,
                        segmentation_line_overlap_threshold: float,
                        segmentation_region_overlap_threshold: float
                        ) -> Optional[SegmentImage]:
        """
        Initialize document segmentation model.

        Args:
            model_path: Path to segmentation model
            max_size: Maximum dimension for image preprocessing
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions

        Returns:
            Initialized SegmentImage instance or None if initialization fails
        """
        try:
            segmenter = SegmentImage(
                model_path=model_path,
                max_size=max_size,
                confidence_threshold=segmentation_confidence_threshold,
                line_percentage_threshold=segmentation_line_percentage_threshold,
                region_percentage_threshold=segmentation_region_percentage_threshold,
                line_iou=segmentation_line_iou,
                region_iou=segmentation_region_iou,
                line_overlap_threshold=segmentation_line_overlap_threshold,
                region_overlap_threshold=segmentation_region_overlap_threshold
            )
            logger.info("✓ Segmentation model initialized")
            return segmenter
        except Exception as e:
            logger.error(f"Failed to initialize segmentation model: {e}")
            return None

    def _init_recognizer(self, model_path: str, batch_size: int) -> Optional[TextRecognition]:
        """
        Initialize text recognition model.

        Args:
            model_path: Path to recognition model directory
            batch_size: Number of text lines to process in parallel

        Returns:
            Initialized TextRecognition instance or None if initialization fails
        """
        try:
            recognizer = TextRecognition(
                model_path=model_path,
                device='cuda:0' if torch.cuda.is_available() else 'cpu',
                batch_size=batch_size
            )
            logger.info("✓ Text recognition model initialized")
            return recognizer
        except Exception as e:
            logger.error(f"Failed to initialize text recognition model: {e}")
            return None

    def _merge_lines(self, segment_predictions: list) -> list:
        """
        Merge text lines from all regions into a single list.

        Args:
            segment_predictions: List of region dictionaries containing line data

        Returns:
            Flat list of all text line polygons
        """
        return [line for region in segment_predictions for line in region.get('lines', [])]
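
    # Illustration (shapes inferred from usage here, not a guaranteed schema): given
    # segment_predictions = [{'lines': [l1, l2]}, {'lines': [l3]}], _merge_lines
    # returns [l1, l2, l3], letting recognition batch lines across region boundaries.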

    def process_image(self, image) -> Dict[str, Any]:
        """
        Process a document image through the complete HTR pipeline.

        Args:
            image: PIL Image object or numpy array

        Returns:
            Dictionary containing:
                - success: bool indicating if processing succeeded
                - segment_predictions: List of detected regions and lines
                - text_predictions: List of recognized text strings
                - processing_time: Time taken in seconds
                - error: Error message if success is False
        """
        start_time = time.time()
        result = {
            'success': False,
            'segment_predictions': None,
            'text_predictions': None,
            'processing_time': 0.0,
            'error': None
        }
        try:
            # Convert PIL image to numpy if needed
            if not isinstance(image, np.ndarray):
                image = np.array(image.convert('RGB'))

            # Run segmentation
            segment_predictions = self.segmenter.get_segmentation(image)
            if not segment_predictions:
                result['error'] = "No text lines detected in the image"
                result['processing_time'] = time.time() - start_time
                return result
            logger.info("✓ Segmentation completed")

            # Extract all lines for recognition
            img_lines = self._merge_lines(segment_predictions)

            # Run text recognition
            text_predictions = self.recognizer.process_lines(img_lines, image)
            logger.info("✓ Text recognition completed")

            result['success'] = True
            result['segment_predictions'] = segment_predictions
            result['text_predictions'] = text_predictions
        except Exception as e:
            logger.error(f"Error during image processing: {e}", exc_info=True)
            result['error'] = str(e)
        finally:
            result['processing_time'] = time.time() - start_time
        return result
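
# Minimal usage sketch for HTRPipeline (illustrative; assumes a local 'page.jpg'
# exists and Pillow is installed):
#
#   from PIL import Image
#   result = pipeline.process_image(Image.open("page.jpg"))
#   if result['success']:
#       print("\n".join(result['text_predictions']))
#   else:
#       print(f"Failed: {result['error']}")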

def is_allowed_source(file_path: Optional[str]) -> bool:
    """
    Check if a file path is from an allowed source.

    This security measure prevents processing of files from untrusted sources,
    limiting uploads to specific domains and temporary directories.

    Args:
        file_path: Path to the uploaded file

    Returns:
        True if source is allowed, False otherwise
    """
    if not file_path:
        logger.warning("No file path provided")
        return False

    # Check if path starts with any allowed source
    is_allowed = any(file_path.startswith(source) for source in Config.ALLOWED_SOURCES)
    if not is_allowed:
        logger.warning(f"File path not allowed: {file_path}")
    return is_allowed
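
# Expected behaviour (paths are illustrative):
#   is_allowed_source("/tmp/gradio/abc123/scan.jpg")    -> True  (Gradio temp dir)
#   is_allowed_source("https://astia.narc.fi/scan.jpg") -> True  (allowed domain prefix)
#   is_allowed_source("/home/user/scan.jpg")            -> False (warning logged)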

async def extract_filepath_from_request(request: gr.Request) -> Optional[str]:
    """
    Extract file path from Gradio request object.

    Args:
        request: Gradio Request object

    Returns:
        File path string or None if not found
    """
    try:
        body = await request.body()
        if not body:
            return None
        body_str = body.decode('utf-8')
        body_json = json.loads(body_str)

        # Navigate through Gradio's request structure
        if 'data' in body_json and isinstance(body_json['data'], list):
            for item in body_json['data']:
                if isinstance(item, dict) and 'path' in item:
                    file_path = item['path']
                    logger.info(f"Extracted file path: {file_path}")
                    return file_path
        return None
    except json.JSONDecodeError:
        logger.warning("Request body is not valid JSON")
        return None
    except Exception as e:
        logger.error(f"Error extracting file path: {e}")
        return None
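
# The parser above assumes Gradio posts a JSON body shaped roughly like
# (structure inferred from the parsing logic; values are illustrative):
#
#   {"data": [{"path": "/tmp/gradio/abc123/scan.jpg"}]}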

# Initialize HTR pipeline
try:
    pipeline = HTRPipeline(
        segmentation_model_path=SEGMENTATION_MODEL_PATH,
        recognition_model_path=TROCR_MODEL_PATH,
        segmentation_max_size=Config.SEGMENTATION_MAX_SIZE,
        recognition_batch_size=Config.RECOGNITION_BATCH_SIZE,
        segmentation_confidence_threshold=Config.SEGMENTATION_CONFIDENCE_THRESHOLD,
        segmentation_line_percentage_threshold=Config.SEGMENTATION_LINE_PERCENTAGE_THRESHOLD,
        segmentation_region_percentage_threshold=Config.SEGMENTATION_REGION_PERCENTAGE_THRESHOLD,
        segmentation_line_iou=Config.SEGMENTATION_LINE_IOU,
        segmentation_region_iou=Config.SEGMENTATION_REGION_IOU,
        segmentation_line_overlap_threshold=Config.SEGMENTATION_LINE_OVERLAP_THRESHOLD,
        segmentation_region_overlap_threshold=Config.SEGMENTATION_REGION_OVERLAP_THRESHOLD
    )
    logger.info("✓ HTR Pipeline initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize HTR pipeline: {e}")
    raise

def create_demo() -> gr.Blocks:
    """
    Create and configure the Gradio demo interface.

    Returns:
        Configured Gradio Blocks interface
    """
    with gr.Blocks(
        theme=gr.themes.Monochrome(),
        title="Multicentury HTR Demo"
    ) as demo:
        gr.Image(
            "logo.png",
            width=200,
            height=100,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            container=False,
            interactive=False
        )
        gr.Markdown("# 📜 Multicentury Handwritten Text Recognition")

        with gr.Tabs():
            # English documentation
            with gr.Tab("English"):
                gr.Markdown("""
                ## About this demo

                This HTR (Handwritten Text Recognition) pipeline combines two machine learning models:
                1. **Text Region & Line Detection**: Identifies text regions and individual lines in document images
                2. **Handwritten Text Recognition**: Transcribes the detected text lines

                The models were trained by the National Archives of Finland in autumn 2025 on handwritten documents
                from the 16th to 20th centuries.

                ### How to use
                1. Upload an image in the **Text Content** tab
                2. Click **Process Image**
                3. View the results: transcribed text, detected regions, and text lines

                ### For best results
                - Use high-quality scans
                - Ensure good contrast between text and background
                - Note that regular document layouts work best

                ⚠️ **Note**: This is a demo application. 24/7 availability is not guaranteed.
                """)

            # Finnish documentation
            with gr.Tab("Suomeksi"):
                gr.Markdown("""
                ## Tietoa demosta

                Käsialantunnistusputki sisältää kaksi koneoppimismallia:
                1. **Tekstialueiden ja -rivien tunnistus**: Tunnistaa tekstialueet ja yksittäiset rivit dokumenttikuvista
                2. **Käsinkirjoitetun tekstin tunnistus**: Litteroi tunnistetut tekstirivit

                Mallit on koulutettu Kansallisarkistossa syksyllä 2025 käsinkirjoitetulla aineistolla,
                joka ajoittuu 1500-luvulta 1900-luvulle.

                ### Käyttöohje
                1. Lataa kuva **Text Content** -välilehdellä
                2. Paina **Process Image** -painiketta
                3. Tarkastele tuloksia: litteroitu teksti, tunnistetut alueet ja tekstirivit

                ### Parhaat tulokset saat kun
                - Käytät korkealaatuisia skannauksia
                - Varmistat hyvän kontrastin tekstin ja taustan välillä
                - Huomioit, että monimutkaiset rakenteet (esim. taulukot) voivat vaikeuttaa tunnistusta

                ⚠️ **Huom**: Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
                """)

        gr.Markdown("---")

        with gr.Tabs():
            with gr.Tab("📝 Text Content"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_img = gr.Image(
                            label="Input Image",
                            type="pil",
                            height=400
                        )
                        with gr.Row():
                            process_btn = gr.Button(
                                "🚀 Process Image",
                                variant="primary",
                                size="lg"
                            )
                            clear_btn = gr.ClearButton(
                                components=[input_img],
                                value="🗑️ Clear"
                            )
                    with gr.Column(scale=1):
                        textbox = gr.Textbox(
                            label="Recognized Text",
                            lines=15,
                            max_lines=30,
                            show_copy_button=True,
                            placeholder="Processed text will appear here..."
                        )
                        download_text_file = gr.File(
                            label="💾 Download Text",
                            visible=False,
                            interactive=False
                        )
                        processing_time = gr.Markdown(
                            "",
                            elem_classes="processing-time"
                        )
                        status_message = gr.Markdown(
                            "",
                            elem_classes="error-message"
                        )

            with gr.Tab("🗺️ Text Regions"):
                region_img = gr.Image(
                    label="Detected Text Regions",
                    type="numpy",
                    height=500
                )
                region_info = gr.Markdown("Upload and process an image to see detected regions")

            with gr.Tab("📏 Text Lines"):
                line_img = gr.Image(
                    label="Detected Text Lines",
                    type="numpy",
                    height=500
                )
                line_info = gr.Markdown("Upload and process an image to see detected text lines")

        async def process_pipeline(image, request: gr.Request):
            """
            Main processing function for the Gradio interface.

            Validates input, checks the file source, runs the HTR pipeline, and formats results.
            """
            # Reset outputs
            outputs = {
                region_img: None,
                line_img: None,
                textbox: "",
                processing_time: "",
                status_message: "",
                download_text_file: gr.update(visible=False, value=None),
                region_info: "",
                line_info: ""
            }

            # Check file source (security measure)
            if request:
                file_path = await extract_filepath_from_request(request)
                if file_path and not is_allowed_source(file_path):
                    outputs[status_message] = "❌ **Error**: File source not allowed for security reasons"
                    yield tuple(outputs.values())
                    return

            # Show processing status
            outputs[status_message] = "⏳ Processing image..."
            yield tuple(outputs.values())

            # Run HTR pipeline
            result = pipeline.process_image(image)

            # Format processing time
            time_str = f"⏱️ Processing time: {result['processing_time']:.2f}s"
            outputs[processing_time] = time_str

            if not result['success']:
                error = result['error'] or "Unknown error occurred"
                outputs[status_message] = f"❌ **Error**: {error}"
                yield tuple(outputs.values())
                return

            # Process successful results
            try:
                segment_predictions = result['segment_predictions']
                text_predictions = result['text_predictions']

                # Generate visualizations
                region_plot = pipeline.plotter.plot_regions(segment_predictions, image)
                line_plot = pipeline.plotter.plot_lines(segment_predictions, image)

                # Format text output
                recognized_text = "\n".join(text_predictions) if text_predictions else ""

                # Update outputs
                outputs[region_img] = region_plot
                outputs[line_img] = line_plot
                outputs[textbox] = recognized_text
                outputs[status_message] = f"Recognized {len(text_predictions)} text lines"

                # Create downloadable text file if text was recognized
                if recognized_text:
                    # Create temporary file with a descriptive filename
                    temp_dir = tempfile.gettempdir()
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    filename = f"htr_result_{timestamp}.txt"
                    filepath = os.path.join(temp_dir, filename)

                    # Write text to file
                    with open(filepath, 'w', encoding='utf-8') as f:
                        f.write(recognized_text)
                    outputs[download_text_file] = gr.update(visible=True, value=filepath)

                # Update info sections
                num_regions = len(segment_predictions)
                outputs[region_info] = f"Detected **{num_regions}** text region(s)"
                outputs[line_info] = f"Detected **{len(text_predictions)}** text line(s)"
            except Exception as e:
                logger.error(f"Error formatting results: {e}", exc_info=True)
                outputs[status_message] = f"❌ **Error**: Failed to format results - {e}"

            yield tuple(outputs.values())

        # Connect button to processing function
        process_btn.click(
            fn=process_pipeline,
            inputs=[input_img],
            outputs=[
                region_img,
                line_img,
                textbox,
                processing_time,
                status_message,
                download_text_file,
                region_info,
                line_info
            ],
            api_name=False  # Disable API endpoint for security
        )

    return demo

# Create and launch demo
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(
        max_size=30,  # 30 users can queue without being rejected
        default_concurrency_limit=1  # Only one image processes at a time
    )
    demo.launch(
        show_error=True,
        max_threads=2  # Minimal threads: 1 for processing + 1 for queue management
    )