import json
import logging
import os
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any

# Third-party imports keep their original relative order (onnxruntime before torch).
from huggingface_hub import login, snapshot_download, hf_hub_download
from transformers import TrOCRProcessor
import gradio as gr
import numpy as np
import onnxruntime
import torch

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # StreamHandler() defaults to stderr; pass stdout explicitly as intended
        # for HF Spaces logs.
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Log startup info for debugging in HF Spaces
logger.info("="*50)
logger.info("HTR Application Starting")
logger.info(f"Python version: {sys.version}")
logger.info(f"Running on Hugging Face Spaces: {os.getenv('SPACE_ID', 'Local')}")
logger.info("="*50)


class Config:
    """Application configuration.

    Only ``HF_TOKEN`` is read from the environment (``HF_TOKEN`` variable);
    every other value is a hard-coded default.
    """
    HF_TOKEN = os.getenv("HF_TOKEN")
    SEGMENTATION_MAX_SIZE = 768
    RECOGNITION_BATCH_SIZE = 10
    SEGMENTATION_CONFIDENCE_THRESHOLD = 0.15
    # NOTE: "PRECENTAGE" spelling is kept because the pipeline setup refers to
    # these exact attribute names.
    SEGMENTATION_LINE_PRECENTAGE_THRESHOLD = 7e-05
    SEGMENTATION_REGION_PRECENTAGE_THRESHOLD = 7e-05
    SEGMENTATION_LINE_IOU = 0.3
    SEGMENTATION_REGION_IOU = 0.3
    SEGMENTATION_LINE_OVERLAP_THRESHOLD = 0.5
    SEGMENTATION_REGION_OVERLAP_THRESHOLD = 0.5
    # Must be a tuple of prefixes. The previous value was a single string
    # ("https://astia.narc.fi, /tmp/gradio"); iterating a str yields single
    # characters such as "/", which let almost any file path pass the
    # source check.
    ALLOWED_SOURCES = ("https://astia.narc.fi", "/tmp/gradio")

    # Model paths
    TROCR_MODEL_REPO = "Kansallisarkisto/multicentury-htr-model-small-onnx"
    SEGMENTATION_MODEL_REPO = "Kansallisarkisto/rfdetr_textline_textregion_detection_model"
    SEGMENTATION_MODEL_FILE = "rfdetr_text_seg_model_202510.pth"


# Login to HuggingFace if token is available
if Config.HF_TOKEN:
    try:
        login(token=Config.HF_TOKEN, add_to_git_credential=True)
        logger.info("✓ Logged in to HuggingFace")
    except Exception as e:
        # Best effort: public models can still be downloaded without a login.
        logger.warning(f"Failed to login to HuggingFace: {e}")
def download_models() -> Tuple[str, str]:
    """
    Download required models from HuggingFace Hub.

    Returns:
        Tuple of (text_recognition_model_path, segmentation_model_path)

    Raises:
        RuntimeError: If model download fails
    """
    try:
        logger.info("Downloading text recognition model...")
        trocr_path = snapshot_download(repo_id=Config.TROCR_MODEL_REPO)
        logger.info(f"✓ Text recognition model downloaded to {trocr_path}")

        logger.info("Downloading segmentation model...")
        seg_path = hf_hub_download(
            repo_id=Config.SEGMENTATION_MODEL_REPO,
            filename=Config.SEGMENTATION_MODEL_FILE
        )
        logger.info(f"✓ Segmentation model downloaded to {seg_path}")

        return trocr_path, seg_path

    except Exception as e:
        logger.error(f"Failed to download models: {e}")
        # Chain the original exception so the full traceback is preserved.
        raise RuntimeError(f"Model download failed: {e}") from e


# Download models at import time; the pipeline below needs both paths.
TROCR_MODEL_PATH, SEGMENTATION_MODEL_PATH = download_models()

# Log CUDA availability
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    logger.info(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


class HTRPipeline:
    """
    Handwritten Text Recognition pipeline combining segmentation and recognition.

    This class manages the initialization and execution of document segmentation
    and text recognition models.
    """

    def __init__(self,
                 segmentation_model_path: str,
                 recognition_model_path: str,
                 segmentation_max_size: int = 768,
                 recognition_batch_size: int = 10,
                 segmentation_confidence_threshold: float = 0.15,
                 segmentation_line_percentage_threshold: float = 7e-05,
                 segmentation_region_percentage_threshold: float = 7e-05,
                 segmentation_line_iou: float = 0.3,
                 segmentation_region_iou: float = 0.3,
                 segmentation_line_overlap_threshold: float = 0.5,
                 segmentation_region_overlap_threshold: float = 0.5
                 ):
        """
        Initialize HTR pipeline with segmentation and recognition models.

        Args:
            segmentation_model_path: Path to segmentation model weights
            recognition_model_path: Path to recognition model directory
            segmentation_max_size: Maximum image dimension for segmentation
            recognition_batch_size: Batch size for text recognition
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions

        Raises:
            RuntimeError: If either the segmenter or the recognizer fails to initialize
        """
        self.segmenter = self._init_segmenter(segmentation_model_path,
                                              segmentation_max_size,
                                              segmentation_confidence_threshold,
                                              segmentation_line_percentage_threshold,
                                              segmentation_region_percentage_threshold,
                                              segmentation_line_iou,
                                              segmentation_region_iou,
                                              segmentation_line_overlap_threshold,
                                              segmentation_region_overlap_threshold
                                              )
        self.recognizer = self._init_recognizer(recognition_model_path, recognition_batch_size)
        self.plotter = PlotHTR()

        if self.segmenter is None or self.recognizer is None:
            raise RuntimeError("Failed to initialize HTR pipeline components")

    def _init_segmenter(self,
                        model_path: str,
                        max_size: int,
                        segmentation_confidence_threshold: float,
                        segmentation_line_percentage_threshold: float,
                        segmentation_region_percentage_threshold: float,
                        segmentation_line_iou: float,
                        segmentation_region_iou: float,
                        segmentation_line_overlap_threshold: float,
                        segmentation_region_overlap_threshold: float
                        ) -> Optional[SegmentImage]:
        """
        Initialize document segmentation model.

        Args:
            model_path: Path to segmentation model
            max_size: Maximum dimension for image preprocessing
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions

        Returns:
            Initialized SegmentImage instance or None if initialization fails
        """
        try:
            segmenter = SegmentImage(
                model_path=model_path,
                max_size=max_size,
                confidence_threshold=segmentation_confidence_threshold,
                line_percentage_threshold=segmentation_line_percentage_threshold,
                region_percentage_threshold=segmentation_region_percentage_threshold,
                line_iou=segmentation_line_iou,
                region_iou=segmentation_region_iou,
                line_overlap_threshold=segmentation_line_overlap_threshold,
                region_overlap_threshold=segmentation_region_overlap_threshold
            )
            logger.info("✓ Segmentation model initialized")
            return segmenter
        except Exception as e:
            logger.error(f"Failed to initialize segmentation model: {e}")
            return None

    def _init_recognizer(self, model_path: str, batch_size: int) -> Optional[TextRecognition]:
        """
        Initialize text recognition model.

        Args:
            model_path: Path to recognition model directory
            batch_size: Number of text lines to process in parallel

        Returns:
            Initialized TextRecognition instance or None if initialization fails
        """
        try:
            recognizer = TextRecognition(
                model_path=model_path,
                device='cuda:0' if torch.cuda.is_available() else 'cpu',
                batch_size=batch_size
            )
            logger.info("✓ Text recognition model initialized")
            return recognizer
        except Exception as e:
            logger.error(f"Failed to initialize text recognition model: {e}")
            return None

    def _merge_lines(self, segment_predictions: list) -> list:
        """
        Merge text lines from all regions into a single list.

        Args:
            segment_predictions: List of region dictionaries containing line data

        Returns:
            Flat list of all text line polygons, in region order
        """
        return [line for region in segment_predictions for line in region.get('lines', [])]

    def process_image(self, image) -> Dict[str, Any]:
        """
        Process a document image through the complete HTR pipeline.

        Args:
            image: PIL Image object or numpy array

        Returns:
            Dictionary containing:
                - success: bool indicating if processing succeeded
                - segment_predictions: List of detected regions and lines
                - text_predictions: List of recognized text strings
                - processing_time: Time taken in seconds
                - error: Error message if success is False
        """
        start_time = time.time()
        result = {
            'success': False,
            'segment_predictions': None,
            'text_predictions': None,
            'processing_time': 0.0,
            'error': None
        }

        try:
            # Convert PIL image to numpy if needed
            if not isinstance(image, np.ndarray):
                image = np.array(image.convert('RGB'))

            # Run segmentation
            segment_predictions = self.segmenter.get_segmentation(image)

            if not segment_predictions:
                # processing_time is filled in by the `finally` clause below.
                result['error'] = "No text lines detected in the image"
                return result

            logger.info("✓ Segmentation completed")

            # Extract all lines for recognition
            img_lines = self._merge_lines(segment_predictions)

            # Run text recognition
            text_predictions = self.recognizer.process_lines(img_lines, image)
            logger.info("✓ Text recognition completed")

            result['success'] = True
            result['segment_predictions'] = segment_predictions
            result['text_predictions'] = text_predictions

        except Exception as e:
            logger.error(f"Error during image processing: {e}", exc_info=True)
            result['error'] = str(e)

        finally:
            # Runs on every exit path, including the early return above.
            result['processing_time'] = time.time() - start_time

        return result


def is_allowed_source(file_path: Optional[str]) -> bool:
    """
    Check if a file path is from an allowed source.

    This security measure prevents processing of files from untrusted sources,
    limiting uploads to specific domains and temporary directories.

    A path is accepted only when it equals an entry of ``Config.ALLOWED_SOURCES``
    or lies beneath it (the source is followed by ``/``). This rejects look-alikes
    such as ``/tmp/gradio-evil`` or ``https://astia.narc.fi.evil.com``. Local
    (non-URL) paths are normalized first so ``..`` segments cannot escape an
    allowed directory.

    Args:
        file_path: Path to the uploaded file

    Returns:
        True if source is allowed, False otherwise
    """
    import posixpath

    if not file_path:
        logger.warning("No file path provided")
        return False

    sources = Config.ALLOWED_SOURCES
    # Defensive: if a single comma-separated string is configured instead of a
    # tuple, iterating it would yield single characters (e.g. "/"), which would
    # make nearly every path "allowed".
    if isinstance(sources, str):
        sources = tuple(part.strip() for part in sources.split(","))

    candidate = file_path
    if "://" not in candidate:
        # Collapse "." / ".." segments so "/tmp/gradio/../etc/x" is rejected.
        candidate = posixpath.normpath(candidate)

    is_allowed = False
    for source in sources:
        prefix = source.rstrip("/")
        if not prefix:
            # An empty prefix would match every path; ignore it.
            continue
        if candidate == prefix or candidate.startswith(prefix + "/"):
            is_allowed = True
            break

    if not is_allowed:
        logger.warning(f"File path not allowed: {file_path}")

    return is_allowed


async def extract_filepath_from_request(request: gr.Request) -> Optional[str]:
    """
    Extract file path from Gradio request object.

    Looks for the first dict with a ``'path'`` key inside the JSON body's
    ``'data'`` list.

    Args:
        request: Gradio Request object

    Returns:
        File path string or None if not found
    """
    try:
        body = await request.body()
        if not body:
            return None

        body_str = body.decode('utf-8')
        body_json = json.loads(body_str)

        # Navigate through Gradio's request structure
        if 'data' in body_json and isinstance(body_json['data'], list):
            for item in body_json['data']:
                if isinstance(item, dict) and 'path' in item:
                    file_path = item['path']
                    logger.info(f"Extracted file path: {file_path}")
                    return file_path

        return None

    except json.JSONDecodeError:
        logger.warning("Request body is not valid JSON")
        return None
    except Exception as e:
        logger.error(f"Error extracting file path: {e}")
        return None


# Initialize HTR pipeline
try:
    pipeline = HTRPipeline(
        segmentation_model_path=SEGMENTATION_MODEL_PATH,
        recognition_model_path=TROCR_MODEL_PATH,
        segmentation_max_size=Config.SEGMENTATION_MAX_SIZE,
        recognition_batch_size=Config.RECOGNITION_BATCH_SIZE,
        segmentation_confidence_threshold=Config.SEGMENTATION_CONFIDENCE_THRESHOLD,
        segmentation_line_percentage_threshold=Config.SEGMENTATION_LINE_PRECENTAGE_THRESHOLD,
        segmentation_region_percentage_threshold=Config.SEGMENTATION_REGION_PRECENTAGE_THRESHOLD,
        segmentation_line_iou=Config.SEGMENTATION_LINE_IOU,
        segmentation_region_iou=Config.SEGMENTATION_REGION_IOU,
        segmentation_line_overlap_threshold=Config.SEGMENTATION_LINE_OVERLAP_THRESHOLD,
        segmentation_region_overlap_threshold=Config.SEGMENTATION_REGION_OVERLAP_THRESHOLD
    )
    logger.info("✓ HTR Pipeline initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize HTR pipeline: {e}")
    raise


def create_demo() -> gr.Blocks:
    """
    Create and configure the Gradio demo interface.

    Returns:
        Configured Gradio Blocks interface
    """
    with gr.Blocks(
        theme=gr.themes.Monochrome(),
        title="Multicentury HTR Demo"
    ) as demo:
        gr.Image("logo.png",
                 width=200,
                 height=100,
                 show_label=False,
                 show_download_button=False,
                 show_fullscreen_button=False,
                 container=False,
                 interactive=False
                 )
        gr.Markdown("# 📜 Multicentury Handwritten Text Recognition")

        with gr.Tabs():
            # English documentation
            with gr.Tab("English"):
                gr.Markdown("""
                ## About this demo

                This HTR (Handwritten Text Recognition) pipeline combines two machine learning models:

                1. **Text Region & Line Detection**: Identifies text regions and individual lines in document images
                2. **Handwritten Text Recognition**: Transcribes the detected text lines

                The models have been trained by the National Archives of Finland in autumn 2025 using handwritten documents from the 16th to 20th centuries.

                ### How to use

                1. Upload an image in the **Text Content** tab
                2. Click **Process Image**
                3. View results: transcribed text, detected regions, and text lines

                ### To obtain best results

                - Use high-quality scans
                - Ensure good contrast between text and background
                - Note that regular document layouts work best

                ⚠️ **Note**: This is a demo application. 24/7 availability is not guaranteed.
                """)

            # Finnish documentation
            with gr.Tab("Suomeksi"):
                gr.Markdown("""
                ## Tietoa demosta

                Käsialantunnistusputki sisältää kaksi koneoppimismallia:

                1. **Tekstialueiden ja -rivien tunnistus**: Tunnistaa tekstialueet ja yksittäiset rivit dokumenttikuvista
                2. **Käsinkirjoitetun tekstin tunnistus**: Litteroi tunnistetut tekstirivit

                Mallit on koulutettu Kansallisarkistossa syksyllä 2025 käsinkirjoitetulla aineistolla, joka ajoittuu 1500-luvulta 1900-luvulle.

                ### Käyttöohje

                1. Lataa kuva **Text Content** -välilehdellä
                2. Paina **Process Image** -painiketta
                3. Tarkastele tuloksia: litteroitu teksti, tunnistetut alueet ja tekstirivit

                ### Parhaat tulokset saat kun

                - Käytät korkealaatuisia skannauksia
                - Varmistat hyvän kontrastin tekstin ja taustan välillä
                - Huomioit että monimutkaiset rakenteet (esim. taulukot) voivat vaikeuttaa tunnistusta

                ⚠️ **Huom**: Tämä on demosovellus. Ympärivuorokautista toimivuutta ei luvata.
                """)

        gr.Markdown("---")

        with gr.Tabs():
            with gr.Tab("📄 Text Content"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_img = gr.Image(
                            label="Input Image",
                            type="pil",
                            height=400
                        )
                        with gr.Row():
                            process_btn = gr.Button(
                                "🚀 Process Image",
                                variant="primary",
                                size="lg"
                            )
                            clear_btn = gr.ClearButton(
                                components=[input_img],
                                value="🗑️ Clear"
                            )

                    with gr.Column(scale=1):
                        textbox = gr.Textbox(
                            label="Recognized Text",
                            lines=15,
                            max_lines=30,
                            show_copy_button=True,
                            placeholder="Processed text will appear here..."
                        )
                        download_text_file = gr.File(
                            label="💾 Download Text",
                            visible=False,
                            interactive=False
                        )
                        processing_time = gr.Markdown(
                            "",
                            elem_classes="processing-time"
                        )
                        status_message = gr.Markdown(
                            "",
                            elem_classes="error-message"
                        )

            with gr.Tab("🗺️ Text Regions"):
                region_img = gr.Image(
                    label="Detected Text Regions",
                    type="numpy",
                    height=500
                )
                region_info = gr.Markdown("Upload and process an image to see detected regions")

            with gr.Tab("📝 Text Lines"):
                line_img = gr.Image(
                    label="Detected Text Lines",
                    type="numpy",
                    height=500
                )
                line_info = gr.Markdown("Upload and process an image to see detected text lines")

        async def process_pipeline(image, request: gr.Request):
            """
            Main processing function for the Gradio interface.

            Validates input, checks file source, runs HTR pipeline, and formats results.
            Yields one tuple of output values per UI update; the dict insertion order
            below matches the ``outputs=`` list passed to ``process_btn.click``.
            """
            # Reset outputs
            outputs = {
                region_img: None,
                line_img: None,
                textbox: "",
                processing_time: "",
                status_message: "",
                download_text_file: gr.update(visible=False, value=None),
                region_info: "",
                line_info: ""
            }

            # Nothing uploaded: fail with a clear message instead of letting the
            # pipeline raise on `None.convert(...)`.
            if image is None:
                outputs[status_message] = "❌ **Error**: Please upload an image first"
                yield tuple(outputs.values())
                return

            # Check file source (security measure)
            if request:
                file_path = await extract_filepath_from_request(request)
                if file_path and not is_allowed_source(file_path):
                    outputs[status_message] = "❌ **Error**: File source not allowed for security reasons"
                    yield tuple(outputs.values())
                    return

            # Show processing status
            outputs[status_message] = "⏳ Processing image..."
            yield tuple(outputs.values())

            # Run HTR pipeline
            result = pipeline.process_image(image)

            # Format processing time
            time_str = f"⏱️ Processing time: {result['processing_time']:.2f}s"
            outputs[processing_time] = time_str

            if not result['success']:
                error = result['error'] or "Unknown error occurred"
                outputs[status_message] = f"❌ **Error**: {error}"
                yield tuple(outputs.values())
                return

            # Process successful results
            try:
                segment_predictions = result['segment_predictions']
                text_predictions = result['text_predictions']

                # Generate visualizations
                region_plot = pipeline.plotter.plot_regions(segment_predictions, image)
                line_plot = pipeline.plotter.plot_lines(segment_predictions, image)

                # Format text output
                recognized_text = "\n".join(text_predictions) if text_predictions else ""

                # Update outputs
                outputs[region_img] = region_plot
                outputs[line_img] = line_plot
                outputs[textbox] = recognized_text
                outputs[status_message] = f"Recognized {len(text_predictions)} text lines"

                # Create downloadable text file if text was recognized
                if recognized_text:
                    # A fresh directory per result prevents two results produced
                    # within the same second from overwriting each other.
                    temp_dir = tempfile.mkdtemp(prefix="htr_")
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    filename = f"htr_result_{timestamp}.txt"
                    filepath = os.path.join(temp_dir, filename)

                    # Write text to file
                    with open(filepath, 'w', encoding='utf-8') as f:
                        f.write(recognized_text)

                    outputs[download_text_file] = gr.update(visible=True, value=filepath)

                # Update info sections
                num_regions = len(segment_predictions)
                outputs[region_info] = f"Detected **{num_regions}** text region(s)"
                outputs[line_info] = f"Detected **{len(text_predictions)}** text line(s)"

            except Exception as e:
                logger.error(f"Error formatting results: {e}", exc_info=True)
                outputs[status_message] = f"❌ **Error**: Failed to format results - {e}"

            yield tuple(outputs.values())

        # Connect button to processing function
        process_btn.click(
            fn=process_pipeline,
            inputs=[input_img],
            outputs=[
                region_img,
                line_img,
                textbox,
                processing_time,
                status_message,
                download_text_file,
                region_info,
                line_info
            ],
            api_name=False  # Disable API endpoint for security
        )

    return demo


# Create and launch demo
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(
        max_size=30,  # 30 users can queue without being rejected
        default_concurrency_limit=1  # Only one image processes at a time
    )
    demo.launch(
        show_error=True,
        max_threads=2  # Minimal threads: 1 for processing + 1 for queue management
    )