# pdf-pages-classifier/classifiers/classifier_onnx.py
# Origin: Hugging Face model repo upload by mciancone
# ("Upload model artifacts and classifier scripts", commit bd27421, verified).
"""PDF page classifier for production inference."""
import json
from pathlib import Path
from typing import Any
import numpy as np
import numpy.typing as npt
try:
from .base_classifier import _BasePDFPageClassifier
except ImportError:
from base_classifier import _BasePDFPageClassifier # standalone / HF usage
try:
import onnxruntime as ort
except ImportError as _e:
raise ImportError(
"onnxruntime is required for inference.\n"
"Install with: pip install onnxruntime"
) from _e
class PDFPageClassifierONNX(_BasePDFPageClassifier):
    """Classify PDF pages using a deployed ONNX model.

    Loads a self-contained deployment directory produced by
    ``export_onnx.save_for_deployment`` and exposes a simple ``predict``
    interface. All preprocessing (center-crop, resize, normalization) is
    performed in pure PIL + numpy, matching the pipeline used during training.

    Example::

        clf = PDFPageClassifierONNX.from_pretrained("outputs/run-42/deployment")
        result = clf.predict("page_001.png")
        print(result["needs_image_embedding"], result["predicted_classes"])
    """

    def __init__(self, model_path: str, config: dict[str, Any]) -> None:
        """Initialise the classifier.

        Args:
            model_path: Path to the ONNX model file.
            config: Deployment config dict (same schema as config.json written
                by save_for_deployment).
        """
        super().__init__(config)
        # One session per classifier instance; default execution providers.
        self._session = ort.InferenceSession(model_path)
        # Cache the graph's (single) input name for reuse in _run_batch.
        self._input_name: str = self._session.get_inputs()[0].name

    @classmethod
    def from_pretrained(cls, model_dir: str) -> "PDFPageClassifierONNX":
        """Load a classifier from a deployment directory.

        The directory must contain:
          - ``model_int8.onnx`` or ``model.onnx`` — exported by
            save_for_deployment (INT8 is preferred when both exist)
          - ``config.json`` — written by save_for_deployment

        Args:
            model_dir: Path to the deployment directory.

        Returns:
            Initialised PDFPageClassifierONNX.

        Raises:
            FileNotFoundError: If ``config.json`` or an ONNX model file is
                missing from ``model_dir``.
        """
        path = Path(model_dir)
        config_path = path / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"config.json not found in {model_dir}")
        # Prefer INT8 (QAT export) over FP32 when both are present
        candidates = ["model_int8.onnx", "model.onnx"]
        for candidate in candidates:
            if (path / candidate).exists():
                model_path = path / candidate
                break
        else:
            raise FileNotFoundError(
                f"No ONNX model found in {model_dir}. "
                f"Expected one of: {', '.join(candidates)}."
            )
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        return cls(str(model_path), config)

    def _run_batch(self, batch_input: "npt.NDArray[np.float32]") -> "npt.NDArray[np.float32]":
        """Run one preprocessed batch through the ONNX session.

        Args:
            batch_input: Float32 batch as produced by the base class's
                preprocessing (layout assumed to match the exported model's
                input — TODO confirm against export_onnx).

        Returns:
            The model's first output tensor as a numpy array.
        """
        # None => return all outputs; we only use the first (logits/scores).
        return self._session.run(None, {self._input_name: batch_input})[0]