|
|
from __future__ import annotations |
|
|
|
|
|
import json |
|
|
import os |
|
|
import warnings |
|
|
from typing import List, Optional |
|
|
|
|
|
try: |
|
|
import onnxruntime as ort |
|
|
except Exception: |
|
|
ort = None |
|
|
|
|
|
from .schema_utils import collect_enums, collect_formats |
|
|
from .types import Prediction |
|
|
|
|
|
|
|
|
class SemanticReasoner:
    """Turn rule-based validation errors into suggested fix predictions.

    ``predict`` applies local heuristics keyed on each error's ``validator``
    field. When the ``"onnx"`` backend is requested the constructor also
    tries to open an ONNX Runtime session; every failure mode is reported
    through ``warnings.warn`` and leaves ``_session`` as ``None``. The
    ``predict`` method in this class does not consult the session.
    """

    def __init__(self, backend: str = "local", onnx_path: Optional[str] = None) -> None:
        """Record the backend choice and, for ``"onnx"``, try to open a session.

        Args:
            backend: Backend name. Only ``"onnx"`` triggers session setup;
                any other value leaves ``_session`` unset.
            onnx_path: Path to the ONNX model. When falsy, the value of the
                ``SEMVAL_ONNX_PATH`` environment variable is used instead.
        """
        self.backend = backend
        self.onnx_path = onnx_path or os.getenv("SEMVAL_ONNX_PATH")
        self._session = None

        if backend != "onnx":
            return

        # Warnings are issued directly from __init__ so that stacklevel=2
        # points at the code constructing the reasoner.
        if ort is None:
            warnings.warn(
                "ONNX backend requested but onnxruntime is not installed; falling back to local heuristics.",
                stacklevel=2,
            )
        elif not self.onnx_path:
            warnings.warn(
                "ONNX backend requested but SEMVAL_ONNX_PATH is not set; falling back to local heuristics.",
                stacklevel=2,
            )
        else:
            try:
                self._session = ort.InferenceSession(self.onnx_path)
            except Exception as e:
                # Deliberate best effort: a broken model must not prevent
                # the heuristic path from working.
                warnings.warn(
                    f"Failed to initialize ONNXRuntime session ({e}); falling back to local heuristics.",
                    stacklevel=2,
                )
                self._session = None

    @staticmethod
    def _make_prediction(error_type: str, jsonpath: str, fix_action: str) -> Prediction:
        """Build the three-key prediction mapping shared by every heuristic."""
        return {
            "error_type": error_type,
            "jsonpath": jsonpath,
            "fix_action": fix_action,
        }

    def predict(
        self, schema_str: str, json_str: str, rule_errors: List[dict]
    ) -> List[Prediction]:
        """Suggest one fix per rule error, where a heuristic applies.

        Args:
            schema_str: JSON Schema document as text.
            json_str: Instance document as text. It is parsed only to check
                that it is valid JSON; its content is not inspected.
            rule_errors: Error mappings. The keys read here are
                ``"jsonpath"``, ``"validator"``, ``"message"`` and
                ``"validator_value"``.

        Returns:
            A list of prediction mappings with ``error_type``, ``jsonpath``
            and ``fix_action`` keys (plus ``fix_value`` for enum errors).
            Empty when ``rule_errors`` is empty. ``minimum``/``maximum``
            errors, and type errors that match no heuristic, contribute
            nothing.
        """
        if not rule_errors:
            return []

        try:
            schema = json.loads(schema_str)
            json.loads(json_str)
        except Exception:
            # Best effort: if either document fails to parse, the schema
            # hints are dropped and only the error fields drive predictions.
            schema = {}

        enum_map = collect_enums(schema)
        fmt_map = collect_formats(schema)

        predictions: List[Prediction] = []

        for err in rule_errors:
            jsonpath = err.get("jsonpath", "$")
            if jsonpath is None:
                jsonpath = "$"
            validator = err.get("validator")
            raw_message = err.get("message", "")
            message = "" if raw_message is None else str(raw_message)

            if validator == "type":
                # Match on the declared type when the error carries it. The
                # message also embeds the offending value, which may itself
                # contain a type name and cause a misclassification.
                declared = err.get("validator_value")
                type_text = str(declared) if declared else message

                # NOTE(review): any format entry at this path is treated as
                # a date, whichever format is declared -- confirm intent.
                if jsonpath.replace("$", "") in fmt_map:
                    predictions.append(
                        self._make_prediction("invalid_date", jsonpath, "parse_date_iso")
                    )
                elif "boolean" in type_text:
                    predictions.append(
                        self._make_prediction("boolean_text", jsonpath, "cast_bool")
                    )
                elif "integer" in type_text or "number" in type_text:
                    predictions.append(
                        self._make_prediction("number_text", jsonpath, "cast_number")
                    )
                # Any other type mismatch intentionally yields no prediction.
            elif validator == "format":
                if str(err.get("validator_value")) == "date":
                    predictions.append(
                        self._make_prediction("invalid_date", jsonpath, "parse_date_iso")
                    )
                else:
                    predictions.append(
                        self._make_prediction("wrong_type", jsonpath, "fill_default")
                    )
            elif validator == "enum":
                allowed = enum_map.get(jsonpath.replace("$", ""), [])
                suggestion = self._make_prediction("enum_near_miss", jsonpath, "map_enum")
                # Enum members need not be strings, so stringify before
                # joining to avoid a TypeError on numeric or null members.
                suggestion["fix_value"] = (
                    ",".join(str(item) for item in allowed) if allowed else None
                )
                predictions.append(suggestion)
            elif validator == "required":
                predictions.append(
                    self._make_prediction("alias_key", jsonpath, "rename_key")
                )
            elif validator in ("minimum", "maximum"):
                # Range violations get no suggested fix.
                continue
            else:
                predictions.append(
                    self._make_prediction("wrong_type", jsonpath, "fill_default")
                )

        return predictions
|
|
|