thearnabsarkar's picture
Upload json_semval/ml_model.py with huggingface_hub
a329232 verified
raw
history blame
5.28 kB
from __future__ import annotations
import json
import os
import warnings
from typing import List, Optional
try:
import onnxruntime as ort
except Exception: # pragma: no cover - optional
ort = None
from .schema_utils import collect_enums, collect_formats
from .types import Prediction
class SemanticReasoner:
    """Suggest repair actions for JSON Schema validation errors.

    Each rule error is mapped to a ``Prediction`` dict with ``error_type``,
    ``jsonpath`` and ``fix_action`` keys (plus ``fix_value`` for enum errors)
    using simple heuristics over the error dict and the schema.
    """

    def __init__(self, backend: str = "local", onnx_path: Optional[str] = None) -> None:
        """Configure the reasoner.

        Args:
            backend: ``"local"`` for heuristics only, or ``"onnx"`` to try to
                open an ONNX Runtime inference session.
            onnx_path: Model path for the ONNX backend; when falsy, the
                ``SEMVAL_ONNX_PATH`` environment variable is used.

        If the ONNX backend cannot be initialised (onnxruntime not importable,
        no model path, or session creation raises), a warning is emitted and
        ``self._session`` stays ``None``; construction does not raise.

        NOTE(review): ``predict`` never reads ``self._session`` -- it always
        uses the heuristics, even when a session was created successfully.
        """
        self.backend = backend
        self.onnx_path = onnx_path or os.getenv("SEMVAL_ONNX_PATH")
        self._session = None
        if backend != "onnx":
            return
        if ort is None:
            warnings.warn(
                "ONNX backend requested but onnxruntime is not installed; falling back to local heuristics.",
                stacklevel=2,
            )
        elif not self.onnx_path:
            warnings.warn(
                "ONNX backend requested but SEMVAL_ONNX_PATH is not set; falling back to local heuristics.",
                stacklevel=2,
            )
        else:
            try:
                self._session = ort.InferenceSession(self.onnx_path)
            except Exception as e:  # best-effort: any init failure degrades to heuristics
                warnings.warn(
                    f"Failed to initialize ONNXRuntime session ({e}); falling back to local heuristics.",
                    stacklevel=2,
                )

    @staticmethod
    def _parse_schema(schema_str: str):
        """Parse ``schema_str`` as JSON, returning ``{}`` if it cannot be parsed.

        ``ValueError`` covers ``json.JSONDecodeError``; ``TypeError`` covers
        non-string input such as ``None``; ``RecursionError`` covers
        pathologically nested input.
        """
        try:
            return json.loads(schema_str)
        except (TypeError, ValueError, RecursionError):
            return {}

    @staticmethod
    def _schema_key(jsonpath: str) -> str:
        """Turn an error's JSONPath into the key used for enum/format lookups.

        NOTE(review): this strips *every* ``$`` character, not only the root
        marker; presumably this matches how ``collect_enums`` /
        ``collect_formats`` build their keys -- confirm for property names
        that contain ``$``.
        """
        return jsonpath.replace("$", "")

    @staticmethod
    def _format_enum_values(allowed) -> Optional[str]:
        """Join allowed enum values into a comma-separated string.

        Returns ``None`` when there are no allowed values. String members are
        emitted verbatim; other JSON values (numbers, booleans, null, ...) are
        rendered with ``json.dumps`` because ``str.join`` rejects non-strings.
        """
        if not allowed:
            return None
        return ",".join(v if isinstance(v, str) else json.dumps(v) for v in allowed)

    @staticmethod
    def _classify_type_error(has_format: bool, message: str) -> tuple[str, str]:
        """Pick ``(error_type, fix_action)`` for a ``type`` validation error.

        ``has_format`` means the schema declares a ``format`` at this path,
        which is treated as a date. Otherwise the expected type is sniffed
        from the error message.

        NOTE(review): the message may also contain the offending value, so
        a value whose text includes "boolean"/"number" can be misclassified.
        """
        if has_format:
            return "invalid_date", "parse_date_iso"
        if "boolean" in message:
            return "boolean_text", "cast_bool"
        if "integer" in message or "number" in message:
            return "number_text", "cast_number"
        # Any other expected type: generic type-mismatch suggestion, matching
        # the fallback used for unrecognised validators.
        return "wrong_type", "fill_default"

    def predict(
        self, schema_str: str, json_str: str, rule_errors: List[dict]
    ) -> List[Prediction]:
        """Return heuristic fix suggestions for ``rule_errors``.

        Args:
            schema_str: JSON Schema as a JSON string. If it cannot be parsed,
                an empty schema is used, which disables enum/format lookups.
            json_str: The instance document. It is not used by the
                heuristics and is kept for interface compatibility.
            rule_errors: Error dicts. The keys read are ``jsonpath`` (missing
                or ``None`` means ``"$"``), ``validator``, ``message`` and,
                for ``format`` errors, ``validator_value``.

        Returns:
            One prediction per error, in input order, except ``minimum`` /
            ``maximum`` errors, which are skipped. Empty input gives ``[]``.
        """
        if not rule_errors:
            return []

        # The instance document is deliberately not parsed: an invalid
        # instance must not throw away a valid schema's enum/format info.
        schema = self._parse_schema(schema_str)
        enum_map = collect_enums(schema)
        fmt_map = collect_formats(schema)

        predictions: List[Prediction] = []
        for err in rule_errors:
            jsonpath = err.get("jsonpath")
            if jsonpath is None:
                jsonpath = "$"
            validator = err.get("validator")
            message = err.get("message") or ""
            key = self._schema_key(jsonpath)

            if validator in ("minimum", "maximum"):
                # Range violations have no heuristic fix; leave them to the rules.
                continue

            if validator == "type":
                error_type, fix_action = self._classify_type_error(key in fmt_map, message)
            elif validator == "format":
                if str(err.get("validator_value")) == "date":
                    error_type, fix_action = "invalid_date", "parse_date_iso"
                else:
                    error_type, fix_action = "wrong_type", "fill_default"
            elif validator == "enum":
                error_type, fix_action = "enum_near_miss", "map_enum"
            elif validator == "required":
                # A missing required property may be present under an alias.
                error_type, fix_action = "alias_key", "rename_key"
            else:
                # Unrecognised validator: generic suggestion.
                error_type, fix_action = "wrong_type", "fill_default"

            prediction = {
                "error_type": error_type,
                "jsonpath": jsonpath,
                "fix_action": fix_action,
            }
            if validator == "enum":
                prediction["fix_value"] = self._format_enum_values(enum_map.get(key, []))
            predictions.append(prediction)
        return predictions