| """PDF page classifier for production inference.""" |
|
|
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| import numpy as np |
| import numpy.typing as npt |
|
|
| try: |
| from .base_classifier import _BasePDFPageClassifier |
| except ImportError: |
| from base_classifier import _BasePDFPageClassifier |
|
|
| try: |
| import onnxruntime as ort |
| except ImportError as _e: |
| raise ImportError( |
| "onnxruntime is required for inference.\n" |
| "Install with: pip install onnxruntime" |
| ) from _e |
|
|
|
|
|
|
class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. ``predict`` and the image preprocessing (center-crop, resize,
    normalization in pure PIL + numpy, matching the training pipeline) are
    inherited from ``_BasePDFPageClassifier``; this subclass only supplies
    the ONNX Runtime backend through ``_run_batch``.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(
        self,
        model_path: str,
        config: dict[str, Any],
        *,
        providers: "list[str] | None" = None,
    ) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment). Passed unchanged to the base class.
            providers: Optional ONNX Runtime execution providers, e.g.
                ``["CUDAExecutionProvider", "CPUExecutionProvider"]``, forwarded
                unchanged to ``ort.InferenceSession``. ``None`` keeps ONNX
                Runtime's default; note that builds exposing several providers
                (such as GPU builds) require this to be set explicitly.
        """
        super().__init__(config)
        self._session = ort.InferenceSession(model_path, providers=providers)
        # Only the first model input is fed; the exported graph is assumed to
        # take a single input tensor.
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str,
        *,
        providers: "list[str] | None" = None,
    ) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:
            - ``config.json`` — written by save_for_deployment
            - an ONNX model: ``model_int8.onnx`` is tried first, then
              ``model.onnx`` (both exported by save_for_deployment)

        Args:
            model_dir: Path to the deployment directory.
            providers: Optional ONNX Runtime execution providers, forwarded to
                the constructor (``None`` keeps ONNX Runtime's default).

        Returns:
            An initialised instance of ``cls``.

        Raises:
            FileNotFoundError: If ``config.json`` is missing, or if none of the
                candidate model files exists in ``model_dir``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")

        # Order matters: the first existing candidate wins.
        candidates = ["model_int8.onnx", "model.onnx"]
        model_path = next(
            (path / name for name in candidates if (path / name).exists()),
            None,
        )
        if model_path is None:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )

        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        return cls(str(model_path), config, providers=providers)

    def _run_batch(
        self, batch_input: "npt.NDArray[np.float32]"
    ) -> "npt.NDArray[np.float32]":
        """Run one forward pass through the ONNX session.

        Args:
            batch_input: Preprocessed batch, fed to the model's first input.

        Returns:
            The model's first output.
        """
        # ``None`` requests every graph output; only the first one is used.
        outputs = self._session.run(None, {self._input_name: batch_input})
        return outputs[0]
|
|