"""PDF page classifier for production inference (ONNX Runtime backend)."""

import json
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union

import numpy as np
import numpy.typing as npt

try:
    from .base_classifier import _BasePDFPageClassifier
except ImportError:
    from base_classifier import _BasePDFPageClassifier  # standalone / HF usage

try:
    import onnxruntime as ort
except ImportError as _e:
    raise ImportError(
        "onnxruntime is required for inference.\n"
        "Install with: pip install onnxruntime"
    ) from _e

# One ONNX Runtime execution provider: either a provider name such as
# "CPUExecutionProvider", or a (provider name, provider-options dict) pair.
_Provider = Union[str, tuple[str, dict[str, Any]]]

# Model files searched for by ``from_pretrained``, in order of preference:
# the INT8 (QAT export) model wins over FP32 when both are present.
_MODEL_CANDIDATES = ("model_int8.onnx", "model.onnx")


class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. All preprocessing (center-crop, resize, normalization) is
    performed in pure PIL + numpy, matching the pipeline used during training.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model_path: str,
        config: dict[str, Any],
        *,
        providers: Optional[Sequence[_Provider]] = None,
    ) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json
                written by save_for_deployment).
            providers: Optional ONNX Runtime execution providers, in order of
                decreasing priority (e.g. ``["CUDAExecutionProvider",
                "CPUExecutionProvider"]``). ``None`` (the default) leaves the
                choice to ONNX Runtime. Builds that ship several providers,
                such as ``onnxruntime-gpu``, may refuse to create a session
                unless this is given explicitly.
        """
        super().__init__(config)
        self._session = ort.InferenceSession(
            model_path,
            providers=list(providers) if providers is not None else None,
        )
        # The exported model is assumed to take a single input tensor; the
        # first input's name is used as the feed key in ``_run_batch``.
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        *,
        providers: Optional[Sequence[_Provider]] = None,
    ) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:

        - ``config.json`` — written by save_for_deployment
        - an ONNX model exported by save_for_deployment, either
          ``model_int8.onnx`` (QAT export) or ``model.onnx``; the INT8 model
          is used when both are present.

        Args:
            model_dir: Path to the deployment directory.
            providers: Optional ONNX Runtime execution providers, forwarded
                to ``__init__`` when given.

        Returns:
            Initialised classifier (an instance of ``cls``).

        Raises:
            FileNotFoundError: If ``config.json`` is missing, or if none of
                the candidate model files exist in ``model_dir``.
        """
        path = Path(model_dir)

        config_path = path / "config.json"
        # is_file() rather than exists(): a directory with this name is not
        # a usable config and should produce the same clear error.
        if not config_path.is_file():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        for candidate in _MODEL_CANDIDATES:
            if (path / candidate).is_file():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(_MODEL_CANDIDATES)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # Only pass ``providers`` when requested so subclasses whose
        # ``__init__`` keeps the original (model_path, config) signature
        # continue to work with ``from_pretrained``.
        if providers is None:
            return cls(str(model_path), config)
        return cls(str(model_path), config, providers=providers)

    def _run_batch(
        self, batch_input: "npt.NDArray[np.float32]"
    ) -> "npt.NDArray[np.float32]":
        """Run the ONNX session on one preprocessed batch.

        Presumably invoked by the base class's ``predict`` after
        preprocessing — NOTE(review): confirm the expected batch shape
        against ``_BasePDFPageClassifier``.

        Args:
            batch_input: Array fed to the model's first input.

        Returns:
            The model's first output for the batch.
        """
        # ``None`` requests every model output; only the first is returned.
        return self._session.run(None, {self._input_name: batch_input})[0]