MikkoLipsanen's picture
Update app.py
c0612a1 verified
import asyncio
import json
import logging
import os
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple, Dict, Any

# Third-party imports keep their original relative order: import order may
# matter for native CUDA/ONNX library loading, so do not re-sort casually.
from huggingface_hub import login, snapshot_download, hf_hub_download
from transformers import TrOCRProcessor
import gradio as gr
import numpy as np
import onnxruntime
import torch

from plotting_functions import PlotHTR
from segment_image import SegmentImage
from onnx_text_recognition import TextRecognition
# Configure logging: INFO level, timestamped records, one stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # StreamHandler() defaults to sys.stderr; pass sys.stdout explicitly
        # so the handler really writes to stdout as intended for HF Spaces.
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Log startup info for debugging in HF Spaces
logger.info("="*50)
logger.info("HTR Application Starting")
# Use the sys module directly; os.sys is an undocumented implementation detail.
logger.info(f"Python version: {sys.version}")
logger.info(f"Running on Hugging Face Spaces: {os.getenv('SPACE_ID', 'Local')}")
logger.info("="*50)
# Configuration from environment variables
class Config:
"""Application configuration from environment variables."""
HF_TOKEN = os.getenv("HF_TOKEN")
SEGMENTATION_MAX_SIZE = 768
RECOGNITION_BATCH_SIZE = 10
SEGMENTATION_CONFIDENCE_THRESHOLD = 0.15
SEGMENTATION_LINE_PRECENTAGE_THRESHOLD = 7e-05
SEGMENTATION_REGION_PRECENTAGE_THRESHOLD = 7e-05
SEGMENTATION_LINE_IOU = 0.3
SEGMENTATION_REGION_IOU = 0.3
SEGMENTATION_LINE_OVERLAP_THRESHOLD = 0.5
SEGMENTATION_REGION_OVERLAP_THRESHOLD = 0.5
ALLOWED_SOURCES = ("https://astia.narc.fi, /tmp/gradio")
# Model paths
TROCR_MODEL_REPO = "Kansallisarkisto/multicentury-htr-model-small-onnx"
SEGMENTATION_MODEL_REPO = "Kansallisarkisto/rfdetr_textline_textregion_detection_model"
SEGMENTATION_MODEL_FILE = "rfdetr_text_seg_model_202510.pth"
# Login to HuggingFace if token is available
if Config.HF_TOKEN:
try:
login(token=Config.HF_TOKEN, add_to_git_credential=True)
logger.info("βœ“ Logged in to HuggingFace")
except Exception as e:
logger.warning(f"Failed to login to HuggingFace: {e}")
def download_models() -> Tuple[str, str]:
    """
    Download required models from HuggingFace Hub.

    The text recognition model is fetched as a whole repository snapshot; the
    segmentation model is a single weights file from its repository.

    Returns:
        Tuple of (text_recognition_model_path, segmentation_model_path)

    Raises:
        RuntimeError: If model download fails. The underlying exception is
            chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        logger.info("Downloading text recognition model...")
        trocr_path = snapshot_download(repo_id=Config.TROCR_MODEL_REPO)
        logger.info(f"βœ“ Text recognition model downloaded to {trocr_path}")

        logger.info("Downloading segmentation model...")
        seg_path = hf_hub_download(
            repo_id=Config.SEGMENTATION_MODEL_REPO,
            filename=Config.SEGMENTATION_MODEL_FILE
        )
        logger.info(f"βœ“ Segmentation model downloaded to {seg_path}")
        return trocr_path, seg_path
    except Exception as e:
        logger.error(f"Failed to download models: {e}", exc_info=True)
        # Chain the original error so the real cause (network, auth, missing
        # file, ...) stays visible in the traceback.
        raise RuntimeError(f"Model download failed: {e}") from e
# Fetch both models once, at import time; these module-level paths are what
# the pipeline construction further down consumes.
[TROCR_MODEL_PATH, SEGMENTATION_MODEL_PATH] = download_models()

# Report whether a CUDA device is usable and, if so, which one is current.
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # get_device_name() with no argument reports the current device.
    logger.info(f"CUDA device: {torch.cuda.get_device_name()}")
class HTRPipeline:
    """
    Handwritten Text Recognition pipeline combining segmentation and recognition.

    This class manages the initialization and execution of document segmentation
    and text recognition models. It also owns a ``PlotHTR`` instance
    (``self.plotter``) that callers use to visualise the detections.
    """

    def __init__(self,
                 segmentation_model_path: str,
                 recognition_model_path: str,
                 segmentation_max_size: int = 768,
                 recognition_batch_size: int = 10,
                 segmentation_confidence_threshold: float = 0.15,
                 segmentation_line_percentage_threshold: float = 7e-05,
                 segmentation_region_percentage_threshold: float = 7e-05,
                 segmentation_line_iou: float = 0.3,
                 segmentation_region_iou: float = 0.3,
                 segmentation_line_overlap_threshold: float = 0.5,
                 segmentation_region_overlap_threshold: float = 0.5
                 ):
        """
        Initialize HTR pipeline with segmentation and recognition models.

        Args:
            segmentation_model_path: Path to segmentation model weights
            recognition_model_path: Path to recognition model directory
            segmentation_max_size: Maximum image dimension for segmentation
            recognition_batch_size: Batch size for text recognition
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions

        Raises:
            RuntimeError: If the segmentation or recognition model fails to initialize.
        """
        self.segmenter = self._init_segmenter(segmentation_model_path,
                                              segmentation_max_size,
                                              segmentation_confidence_threshold,
                                              segmentation_line_percentage_threshold,
                                              segmentation_region_percentage_threshold,
                                              segmentation_line_iou,
                                              segmentation_region_iou,
                                              segmentation_line_overlap_threshold,
                                              segmentation_region_overlap_threshold
                                              )
        # Fail fast: do not load the recognition model when segmentation is
        # already unavailable, since the pipeline cannot run without it.
        if self.segmenter is None:
            raise RuntimeError("Failed to initialize HTR pipeline components")

        self.recognizer = self._init_recognizer(recognition_model_path, recognition_batch_size)
        if self.recognizer is None:
            raise RuntimeError("Failed to initialize HTR pipeline components")

        self.plotter = PlotHTR()

    def _init_segmenter(self,
                        model_path: str,
                        max_size: int,
                        segmentation_confidence_threshold: float,
                        segmentation_line_percentage_threshold: float,
                        segmentation_region_percentage_threshold: float,
                        segmentation_line_iou: float,
                        segmentation_region_iou: float,
                        segmentation_line_overlap_threshold: float,
                        segmentation_region_overlap_threshold: float
                        ) -> Optional[SegmentImage]:
        """
        Initialize document segmentation model.

        Args:
            model_path: Path to segmentation model
            max_size: Maximum dimension for image preprocessing
            segmentation_confidence_threshold: Minimum confidence score for detections
            segmentation_line_percentage_threshold: Minimum polygon area as fraction of image area for lines
            segmentation_region_percentage_threshold: Minimum polygon area as fraction of image area for regions
            segmentation_line_iou: IoU threshold for merging overlapping line polygons
            segmentation_region_iou: IoU threshold for merging overlapping region polygons
            segmentation_line_overlap_threshold: Area overlap ratio threshold for merging lines
            segmentation_region_overlap_threshold: Area overlap ratio threshold for merging regions

        Returns:
            Initialized SegmentImage instance or None if initialization fails
            (the failure is logged with its traceback).
        """
        try:
            segmenter = SegmentImage(
                model_path=model_path,
                max_size=max_size,
                confidence_threshold=segmentation_confidence_threshold,
                line_percentage_threshold=segmentation_line_percentage_threshold,
                region_percentage_threshold=segmentation_region_percentage_threshold,
                line_iou=segmentation_line_iou,
                region_iou=segmentation_region_iou,
                line_overlap_threshold=segmentation_line_overlap_threshold,
                region_overlap_threshold=segmentation_region_overlap_threshold
            )
            logger.info("βœ“ Segmentation model initialized")
            return segmenter
        except Exception as e:
            logger.error(f"Failed to initialize segmentation model: {e}", exc_info=True)
            return None

    def _init_recognizer(self, model_path: str, batch_size: int) -> Optional[TextRecognition]:
        """
        Initialize text recognition model.

        Args:
            model_path: Path to recognition model directory
            batch_size: Number of text lines to process in parallel

        Returns:
            Initialized TextRecognition instance or None if initialization fails
            (the failure is logged with its traceback).
        """
        try:
            recognizer = TextRecognition(
                model_path=model_path,
                # First CUDA device when available, otherwise CPU.
                device='cuda:0' if torch.cuda.is_available() else 'cpu',
                batch_size=batch_size
            )
            logger.info("βœ“ Text recognition model initialized")
            return recognizer
        except Exception as e:
            logger.error(f"Failed to initialize text recognition model: {e}", exc_info=True)
            return None

    def _merge_lines(self, segment_predictions: list) -> list:
        """
        Merge text lines from all regions into a single list.

        Args:
            segment_predictions: List of region dictionaries containing line data

        Returns:
            Flat list of all text line polygons, in region order. Regions with
            a missing or ``None`` ``'lines'`` entry contribute nothing.
        """
        return [line for region in segment_predictions for line in (region.get('lines') or [])]

    def process_image(self, image) -> Dict[str, Any]:
        """
        Process a document image through the complete HTR pipeline.

        Args:
            image: PIL Image object or numpy array. Non-array inputs are
                converted with ``.convert('RGB')``; arrays are used as given
                (presumably already RGB — confirm with callers).

        Returns:
            Dictionary containing:
            - success: bool indicating if processing succeeded
            - segment_predictions: List of detected regions and lines
            - text_predictions: List of recognized text strings (empty when
              the detected regions contain no lines)
            - processing_time: Time taken in seconds
            - error: Error message if success is False
        """
        # perf_counter is monotonic and meant for measuring elapsed time.
        start_time = time.perf_counter()
        result = {
            'success': False,
            'segment_predictions': None,
            'text_predictions': None,
            'processing_time': 0.0,
            'error': None
        }

        try:
            if image is None:
                result['error'] = "No image provided"
                return result

            # Convert PIL image to numpy if needed
            if not isinstance(image, np.ndarray):
                image = np.array(image.convert('RGB'))

            # Run segmentation
            segment_predictions = self.segmenter.get_segmentation(image)
            if not segment_predictions:
                result['error'] = "No text lines detected in the image"
                return result
            logger.info("βœ“ Segmentation completed")

            # Extract all lines for recognition
            img_lines = self._merge_lines(segment_predictions)

            # Regions can be detected without any lines inside them; skip the
            # recognizer instead of handing it an empty batch.
            if img_lines:
                text_predictions = self.recognizer.process_lines(img_lines, image)
            else:
                text_predictions = []
            logger.info("βœ“ Text recognition completed")

            result['success'] = True
            result['segment_predictions'] = segment_predictions
            result['text_predictions'] = text_predictions
        except Exception as e:
            logger.error(f"Error during image processing: {e}", exc_info=True)
            result['error'] = str(e)
        finally:
            # Runs on every exit path, including the early returns above,
            # which hand back this same dict object.
            result['processing_time'] = time.perf_counter() - start_time

        return result
def _matches_source(file_path: str, source: str) -> bool:
    """Return True if *file_path* lies at or under *source*, respecting segment boundaries."""
    if "://" in source:
        # URL prefix: the next character after the prefix must end the host/path
        # segment, so "https://host.evil.com" does not match "https://host".
        prefix = source.rstrip("/")
        if not file_path.startswith(prefix):
            return False
        return file_path[len(prefix):len(prefix) + 1] in ("", "/", "?", "#")
    # Local path: normalise first so ".." components and duplicate separators
    # cannot escape the allowed directory, then require a separator boundary
    # so "/tmp/gradio_evil" does not match "/tmp/gradio".
    root = os.path.normpath(source)
    candidate = os.path.normpath(file_path)
    if candidate == root:
        return True
    return candidate.startswith(root.rstrip(os.sep) + os.sep)


def is_allowed_source(file_path: Optional[str], allowed_sources: Any = None) -> bool:
    """
    Check if a file path is from an allowed source.

    This security measure prevents processing of files from untrusted sources,
    limiting uploads to specific domains and temporary directories.

    Args:
        file_path: Path (or URL) of the uploaded file
        allowed_sources: Allowed prefixes, either an iterable of strings or a
            single comma-separated string. Defaults to ``Config.ALLOWED_SOURCES``.

    Returns:
        True if source is allowed, False otherwise (including when
        ``file_path`` is empty or None).
    """
    if not file_path:
        logger.warning("No file path provided")
        return False

    sources = Config.ALLOWED_SOURCES if allowed_sources is None else allowed_sources
    # Accept a comma-separated string too: iterating a plain string would
    # otherwise test one character at a time (e.g. "/") and allow nearly anything.
    if isinstance(sources, str):
        sources = [part.strip() for part in sources.split(",")]

    # Empty entries are skipped; they must never act as a catch-all prefix.
    is_allowed = any(_matches_source(file_path, source) for source in sources if source)
    if not is_allowed:
        logger.warning(f"File path not allowed: {file_path}")
    return is_allowed
async def extract_filepath_from_request(request: gr.Request) -> Optional[str]:
    """
    Look up an uploaded file's path in the JSON body of a Gradio request.

    The body is decoded as UTF-8 JSON. If it has a ``"data"`` key holding a
    list, the first dictionary in that list with a ``"path"`` key supplies
    the result.

    Args:
        request: Gradio Request object

    Returns:
        The first ``"path"`` value found, or None when the body is empty, is
        not valid JSON, contains no matching entry, or cannot be read.
    """
    try:
        raw_body = await request.body()
        if not raw_body:
            return None

        payload = json.loads(raw_body.decode('utf-8'))

        # Gradio's request structure carries component values under "data".
        if not ('data' in payload and isinstance(payload['data'], list)):
            return None

        # Sentinel distinguishes "no matching entry" from an entry whose path is None.
        no_match = object()
        file_path = next(
            (entry['path'] for entry in payload['data']
             if isinstance(entry, dict) and 'path' in entry),
            no_match,
        )
        if file_path is no_match:
            return None

        logger.info(f"Extracted file path: {file_path}")
        return file_path

    except json.JSONDecodeError:
        logger.warning("Request body is not valid JSON")
        return None
    except Exception as e:
        logger.error(f"Error extracting file path: {e}")
        return None
# Build the single module-wide HTR pipeline from Config values.
# Any failure is logged and then re-raised, so start-up stops here.
try:
    pipeline = HTRPipeline(
        segmentation_model_path=SEGMENTATION_MODEL_PATH,
        recognition_model_path=TROCR_MODEL_PATH,
        segmentation_max_size=Config.SEGMENTATION_MAX_SIZE,
        recognition_batch_size=Config.RECOGNITION_BATCH_SIZE,
        segmentation_confidence_threshold=Config.SEGMENTATION_CONFIDENCE_THRESHOLD,
        segmentation_line_percentage_threshold=Config.SEGMENTATION_LINE_PRECENTAGE_THRESHOLD,
        segmentation_region_percentage_threshold=Config.SEGMENTATION_REGION_PRECENTAGE_THRESHOLD,
        segmentation_line_iou=Config.SEGMENTATION_LINE_IOU,
        segmentation_region_iou=Config.SEGMENTATION_REGION_IOU,
        segmentation_line_overlap_threshold=Config.SEGMENTATION_LINE_OVERLAP_THRESHOLD,
        segmentation_region_overlap_threshold=Config.SEGMENTATION_REGION_OVERLAP_THRESHOLD,
    )
except Exception as e:
    logger.error(f"Failed to initialize HTR pipeline: {e}")
    raise
else:
    logger.info("βœ“ HTR Pipeline initialized successfully")
def create_demo() -> gr.Blocks:
    """
    Create and configure the Gradio demo interface.

    The interface shows a logo and title, two documentation tabs (English and
    Finnish), and three result tabs: recognized text with a download link,
    a text-region overlay, and a text-line overlay. Clicking "Process Image"
    runs the module-level ``pipeline`` on the uploaded image.

    Returns:
        Configured Gradio Blocks interface
    """
    with gr.Blocks(
        theme=gr.themes.Monochrome(),
        title="Multicentury HTR Demo"
    ) as demo:
        gr.Image("logo.png",
                 width=200,
                 height=100,
                 show_label=False,
                 show_download_button=False,
                 show_fullscreen_button=False,
                 container=False,
                 interactive=False
                 )
        gr.Markdown("# πŸ“œ Multicentury Handwritten Text Recognition")

        with gr.Tabs():
            # English documentation
            with gr.Tab("English"):
                gr.Markdown("""
## About this demo
This HTR (Handwritten Text Recognition) pipeline combines two machine learning models:
1. **Text Region & Line Detection**: Identifies text regions and individual lines in document images
2. **Handwritten Text Recognition**: Transcribes the detected text lines
The models have been trained by the National Archives of Finland in autumn 2025 using handwritten documents
from the 16th to 20th centuries.
### How to use
1. Upload an image in the **Text Content** tab
2. Click **Process Image**
3. View results: transcribed text, detected regions, and text lines
### To obtain best results
- Use high-quality scans
- Ensure good contrast between text and background
- Note that regular document layouts work best
⚠️ **Note**: This is a demo application. 24/7 availability is not guaranteed.
""")
            # Finnish documentation
            with gr.Tab("Suomeksi"):
                gr.Markdown("""
## Tietoa demosta
KÀsialantunnistusputki sisÀltÀÀ kaksi koneoppimismallia:
1. **Tekstialueiden ja -rivien tunnistus**: Tunnistaa tekstialueet ja yksittΓ€iset rivit dokumenttikuvista
2. **KΓ€sinkirjoitetun tekstin tunnistus**: Litteroi tunnistetut tekstirivit
Mallit on koulutettu Kansallisarkistossa syksyllΓ€ 2025 kΓ€sinkirjoitetulla aineistolla,
joka ajoittuu 1500-luvulta 1900-luvulle.
### KΓ€yttΓΆohje
1. Lataa kuva **Text Content** -vΓ€lilehdellΓ€
2. Paina **Process Image** -painiketta
3. Tarkastele tuloksia: litteroitu teksti, tunnistetut alueet ja tekstirivit
### Parhaat tulokset saat kun
- KΓ€ytΓ€t korkealaatuisia skannauksia
- Varmistat hyvΓ€n kontrastin tekstin ja taustan vΓ€lillΓ€
- Huomioit ettΓ€ monimutkaiset rakenteet (esim. taulukot) voivat vaikeuttaa tunnistusta
⚠️ **Huom**: TÀmÀ on demosovellus. YmpÀrivuorokautista toimivuutta ei luvata.
""")

        gr.Markdown("---")

        with gr.Tabs():
            with gr.Tab("πŸ“„ Text Content"):
                with gr.Row():
                    with gr.Column(scale=1):
                        input_img = gr.Image(
                            label="Input Image",
                            type="pil",
                            height=400
                        )
                        with gr.Row():
                            process_btn = gr.Button(
                                "πŸš€ Process Image",
                                variant="primary",
                                size="lg"
                            )
                            clear_btn = gr.ClearButton(
                                components=[input_img],
                                value="πŸ—‘οΈ Clear"
                            )
                    with gr.Column(scale=1):
                        textbox = gr.Textbox(
                            label="Recognized Text",
                            lines=15,
                            max_lines=30,
                            show_copy_button=True,
                            placeholder="Processed text will appear here..."
                        )
                        download_text_file = gr.File(
                            label="πŸ’Ύ Download Text",
                            visible=False,
                            interactive=False
                        )
                        processing_time = gr.Markdown(
                            "",
                            elem_classes="processing-time"
                        )
                        status_message = gr.Markdown(
                            "",
                            elem_classes="error-message"
                        )

            with gr.Tab("πŸ—ΊοΈ Text Regions"):
                region_img = gr.Image(
                    label="Detected Text Regions",
                    type="numpy",
                    height=500
                )
                region_info = gr.Markdown("Upload and process an image to see detected regions")

            with gr.Tab("πŸ“ Text Lines"):
                line_img = gr.Image(
                    label="Detected Text Lines",
                    type="numpy",
                    height=500
                )
                line_info = gr.Markdown("Upload and process an image to see detected text lines")

        # Single source of truth for the output order: used both by click()
        # and by every yield, so the two can never drift apart.
        output_components = [
            region_img,
            line_img,
            textbox,
            processing_time,
            status_message,
            download_text_file,
            region_info,
            line_info
        ]

        # Clear the previous results together with the input image, so stale
        # output is not shown next to a newly uploaded image.
        clear_btn.add([textbox, region_img, line_img, download_text_file,
                       processing_time, status_message])

        def _as_tuple(values: dict) -> tuple:
            """Order a component->value mapping to match output_components."""
            return tuple(values[component] for component in output_components)

        async def process_pipeline(image, request: gr.Request):
            """
            Main processing function for the Gradio interface.

            Validates input, checks file source, runs HTR pipeline, and formats
            results. Yields a "processing" status first, then the final outputs.
            """
            # Reset outputs
            outputs = {
                region_img: None,
                line_img: None,
                textbox: "",
                processing_time: "",
                status_message: "",
                download_text_file: gr.update(visible=False, value=None),
                region_info: "",
                line_info: ""
            }

            # Check file source (security measure)
            if request:
                file_path = await extract_filepath_from_request(request)
                if file_path and not is_allowed_source(file_path):
                    outputs[status_message] = "❌ **Error**: File source not allowed for security reasons"
                    yield _as_tuple(outputs)
                    return

            # Without an image there is nothing to process; answer right away
            # instead of surfacing an AttributeError from the pipeline.
            if image is None:
                outputs[status_message] = "❌ **Error**: Please upload an image first"
                yield _as_tuple(outputs)
                return

            # Show processing status
            outputs[status_message] = "⏳ Processing image..."
            yield _as_tuple(outputs)

            # process_image is synchronous and heavy. Calling it directly in this
            # async handler would block the event loop (freezing the server and
            # delaying the status update above), so run it in a worker thread.
            result = await asyncio.to_thread(pipeline.process_image, image)

            # Format processing time
            outputs[processing_time] = f"⏱️ Processing time: {result['processing_time']:.2f}s"

            if not result['success']:
                error = result['error'] or "Unknown error occurred"
                outputs[status_message] = f"❌ **Error**: {error}"
                yield _as_tuple(outputs)
                return

            # Process successful results
            try:
                segment_predictions = result['segment_predictions']
                # Guard against a None prediction list so len() below is safe.
                text_predictions = result['text_predictions'] or []

                # Generate visualizations
                region_plot = pipeline.plotter.plot_regions(segment_predictions, image)
                line_plot = pipeline.plotter.plot_lines(segment_predictions, image)

                # Format text output
                recognized_text = "\n".join(text_predictions)

                # Update outputs
                outputs[region_img] = region_plot
                outputs[line_img] = line_plot
                outputs[textbox] = recognized_text
                outputs[status_message] = f"Recognized {len(text_predictions)} text lines"

                # Create downloadable text file if text was recognized
                if recognized_text:
                    # Create temporary file with proper filename
                    temp_dir = tempfile.gettempdir()
                    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    filename = f"htr_result_{timestamp}.txt"
                    filepath = os.path.join(temp_dir, filename)
                    with open(filepath, 'w', encoding='utf-8') as f:
                        f.write(recognized_text)
                    outputs[download_text_file] = gr.update(visible=True, value=filepath)

                # Update info sections
                num_regions = len(segment_predictions)
                outputs[region_info] = f"Detected **{num_regions}** text region(s)"
                outputs[line_info] = f"Detected **{len(text_predictions)}** text line(s)"
            except Exception as e:
                logger.error(f"Error formatting results: {e}", exc_info=True)
                outputs[status_message] = f"❌ **Error**: Failed to format results - {e}"

            yield _as_tuple(outputs)

        # Connect button to processing function
        process_btn.click(
            fn=process_pipeline,
            inputs=[input_img],
            outputs=output_components,
            api_name=False  # Disable API endpoint for security
        )

    return demo
# Build the interface and start the server only when executed as a script.
if __name__ == "__main__":
    demo = create_demo()

    # Queueing: each event handles one image at a time, and up to 30
    # requests may wait in line before new ones are rejected.
    demo.queue(
        default_concurrency_limit=1,
        max_size=30,
    )

    # Errors are shown to the user; Gradio's thread pool is kept minimal
    # (intended as one thread for processing plus one for queue management).
    demo.launch(
        max_threads=2,
        show_error=True,
    )