File size: 5,276 Bytes
a329232 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from __future__ import annotations
import json
import os
import warnings
from typing import List, Optional
try:
import onnxruntime as ort
except Exception: # pragma: no cover - optional
ort = None
from .schema_utils import collect_enums, collect_formats
from .types import Prediction
class SemanticReasoner:
    """Suggest fixes for JSON Schema validation errors reported by a rule engine.

    Only a heuristic mapping from validator errors to candidate fix actions is
    implemented. ``backend="onnx"`` tries to open an ONNX Runtime session and
    stores it in ``self._session``, but :meth:`predict` does not consult it.
    NOTE(review): wiring the session into ``predict`` is presumably future work.
    """

    def __init__(self, backend: str = "local", onnx_path: Optional[str] = None) -> None:
        """Configure the inference backend.

        Args:
            backend: ``"onnx"`` requests an ONNX Runtime session; any other
                value uses only the local heuristics.
            onnx_path: Model path for the ONNX backend. When falsy, the
                ``SEMVAL_ONNX_PATH`` environment variable is used instead.

        If the ONNX backend is requested but onnxruntime could not be imported,
        no model path is configured, or the session cannot be created, a
        warning is issued and ``self._session`` stays ``None``. Construction
        does not raise in those cases.
        """
        self.backend = backend
        self.onnx_path = onnx_path or os.getenv("SEMVAL_ONNX_PATH")
        self._session = None
        if backend == "onnx":
            if ort is None:
                warnings.warn(
                    "ONNX backend requested but onnxruntime is not installed; falling back to local heuristics.",
                    stacklevel=2,
                )
            elif not self.onnx_path:
                warnings.warn(
                    "ONNX backend requested but SEMVAL_ONNX_PATH is not set; falling back to local heuristics.",
                    stacklevel=2,
                )
            else:
                try:
                    self._session = ort.InferenceSession(self.onnx_path)
                except Exception as e:  # deliberate best-effort: degrade to heuristics
                    warnings.warn(
                        f"Failed to initialize ONNXRuntime session ({e}); falling back to local heuristics.",
                        stacklevel=2,
                    )
                    self._session = None

    def predict(
        self, schema_str: str, json_str: str, rule_errors: List[dict]
    ) -> List[Prediction]:
        """Map rule-based validation errors to suggested fixes.

        Args:
            schema_str: The JSON Schema as a JSON string. If it cannot be
                parsed, or is not a JSON object, it is treated as ``{}``. In
                that case enum and format lookups find nothing.
            json_str: The instance document. The heuristics do not use it;
                the parameter is kept for interface compatibility.
            rule_errors: Error dicts. The keys read are ``"jsonpath"``
                (default ``"$"``), ``"validator"``, ``"message"`` and
                ``"validator_value"``.

        Returns:
            One prediction dict per error, in input order. The exception is
            ``minimum``/``maximum`` errors, which are left to the rules and
            produce no prediction.
        """
        if not rule_errors:
            return []
        # The schema is parsed independently of the instance document, so a
        # malformed instance no longer discards valid enum/format information.
        schema = self._parse_schema(schema_str)
        enum_map = collect_enums(schema)
        fmt_map = collect_formats(schema)
        predictions: List[Prediction] = []
        for err in rule_errors:
            prediction = self._suggest_fix(err, enum_map, fmt_map)
            if prediction is not None:
                predictions.append(prediction)
        return predictions

    @staticmethod
    def _parse_schema(schema_str: str) -> dict:
        """Parse the schema JSON, falling back to ``{}`` when it is unusable."""
        try:
            schema = json.loads(schema_str)
        except (TypeError, ValueError, RecursionError):
            # ValueError covers json.JSONDecodeError. TypeError covers
            # non-str/bytes input such as None. RecursionError covers
            # pathologically deep nesting.
            return {}
        # NOTE(review): assumes collect_enums/collect_formats expect a dict;
        # a non-object top level (e.g. a boolean schema) is replaced by {}.
        return schema if isinstance(schema, dict) else {}

    @staticmethod
    def _schema_key(jsonpath: str) -> str:
        """Derive the lookup key used for the enum/format maps from a jsonpath."""
        # NOTE(review): removes every "$" (not only the root marker). This
        # mirrors the original derivation; confirm it matches collect_enums
        # and collect_formats.
        return jsonpath.replace("$", "")

    @staticmethod
    def _prediction(error_type: str, jsonpath: str, fix_action: str) -> Prediction:
        """Build a prediction dict with the common keys."""
        return {
            "error_type": error_type,
            "jsonpath": jsonpath,
            "fix_action": fix_action,
        }

    def _suggest_fix(self, err: dict, enum_map, fmt_map) -> Optional[Prediction]:
        """Return the heuristic fix for a single rule error, or None to skip it."""
        jsonpath = err.get("jsonpath", "$")
        validator = err.get("validator")
        # `or ""` guards against an explicit None message, which would make
        # the substring checks below raise TypeError.
        message = str(err.get("message") or "")

        if validator == "type":
            # NOTE(review): any declared format at this path is treated as a
            # date. Confirm whether non-date formats should be excluded.
            if self._schema_key(jsonpath) in fmt_map:
                return self._prediction("invalid_date", jsonpath, "parse_date_iso")
            if "boolean" in message:
                return self._prediction("boolean_text", jsonpath, "cast_bool")
            if "integer" in message or "number" in message:
                return self._prediction("number_text", jsonpath, "cast_number")
            # Other type mismatches get the same no-op suggestion as every
            # other unmatched error instead of being silently dropped.
            return self._prediction("wrong_type", jsonpath, "fill_default")

        if validator == "format":
            # e.g. date format violations
            if str(err.get("validator_value")) == "date":
                return self._prediction("invalid_date", jsonpath, "parse_date_iso")
            return self._prediction("wrong_type", jsonpath, "fill_default")

        if validator == "enum":
            allowed = enum_map.get(self._schema_key(jsonpath), [])
            prediction = self._prediction("enum_near_miss", jsonpath, "map_enum")
            # JSON Schema enums may hold numbers, booleans or null. Serialize
            # non-strings as JSON so str.join does not raise TypeError.
            prediction["fix_value"] = (
                ",".join(v if isinstance(v, str) else json.dumps(v) for v in allowed)
                if allowed
                else None
            )
            return prediction

        if validator == "required":
            # When a required property is missing, guess it exists under an alias.
            return self._prediction("alias_key", jsonpath, "rename_key")

        if validator in ("minimum", "maximum"):
            # No direct fix; leave to rules.
            return None

        # Fallback: no-op suggestion.
        return self._prediction("wrong_type", jsonpath, "fill_default")
|