Upload json_semval/pipeline.py with huggingface_hub
Browse files- json_semval/pipeline.py +132 -0
json_semval/pipeline.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from .fixes import cast_bool, cast_number, map_enum, parse_date_iso, rename_key
|
| 8 |
+
from .ml_model import SemanticReasoner
|
| 9 |
+
from .rules_engine import validate_with_jsonschema
|
| 10 |
+
from .schema_utils import collect_enums
|
| 11 |
+
from .types import Prediction, Report
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _apply_fix(schema: Dict[str, Any], payload: Any, pred: Prediction) -> Any | None:
|
| 15 |
+
path = pred.get("jsonpath", "$")
|
| 16 |
+
if not path.startswith("$"):
|
| 17 |
+
return None
|
| 18 |
+
# convert to tokens
|
| 19 |
+
tokens: List[str] = []
|
| 20 |
+
rest = path[1:]
|
| 21 |
+
i = 0
|
| 22 |
+
while i < len(rest):
|
| 23 |
+
ch = rest[i]
|
| 24 |
+
if ch == ".":
|
| 25 |
+
j = i + 1
|
| 26 |
+
name = []
|
| 27 |
+
while j < len(rest) and rest[j] not in ".[":
|
| 28 |
+
name.append(rest[j])
|
| 29 |
+
j += 1
|
| 30 |
+
if name:
|
| 31 |
+
tokens.append("." + "".join(name))
|
| 32 |
+
i = j
|
| 33 |
+
continue
|
| 34 |
+
if ch == "[":
|
| 35 |
+
j = rest.find("]", i)
|
| 36 |
+
tokens.append(rest[i : j + 1])
|
| 37 |
+
i = j + 1
|
| 38 |
+
continue
|
| 39 |
+
i += 1
|
| 40 |
+
|
| 41 |
+
action = pred.get("fix_action", "")
|
| 42 |
+
if action == "rename_key":
|
| 43 |
+
dst = pred.get("fix_value") or "_renamed"
|
| 44 |
+
try:
|
| 45 |
+
return rename_key(payload, tokens, dst)
|
| 46 |
+
except Exception:
|
| 47 |
+
return None
|
| 48 |
+
if action == "cast_number":
|
| 49 |
+
return cast_number(payload, tokens)
|
| 50 |
+
if action == "cast_bool":
|
| 51 |
+
return cast_bool(payload, tokens)
|
| 52 |
+
if action == "parse_date_iso":
|
| 53 |
+
return parse_date_iso(payload, tokens)
|
| 54 |
+
if action == "map_enum":
|
| 55 |
+
enums = collect_enums(schema)
|
| 56 |
+
allowed = enums.get(path.replace("$", ""), [])
|
| 57 |
+
return map_enum(payload, tokens, allowed)
|
| 58 |
+
# fill_default or unknown → skip
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def run_validation(
|
| 63 |
+
schema: Dict[str, Any],
|
| 64 |
+
payload: Any,
|
| 65 |
+
*,
|
| 66 |
+
apply_fixes: bool = True,
|
| 67 |
+
max_fixes: int = 5,
|
| 68 |
+
backend: str = "local",
|
| 69 |
+
) -> Report:
|
| 70 |
+
is_valid, errors = validate_with_jsonschema(schema, payload)
|
| 71 |
+
if is_valid:
|
| 72 |
+
return {
|
| 73 |
+
"valid": True,
|
| 74 |
+
"rule_errors": [],
|
| 75 |
+
"ml_predictions": [],
|
| 76 |
+
"applied_fixes": [],
|
| 77 |
+
"corrected_json": payload,
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
# Honor explicit rules-only backend: do not invoke ML or apply fixes
|
| 81 |
+
if backend == "rules-only":
|
| 82 |
+
return {
|
| 83 |
+
"valid": False,
|
| 84 |
+
"rule_errors": errors,
|
| 85 |
+
"ml_predictions": [],
|
| 86 |
+
"applied_fixes": [],
|
| 87 |
+
"corrected_json": payload,
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
reasoner = SemanticReasoner(backend=backend)
|
| 91 |
+
preds = reasoner.predict(json.dumps(schema), json.dumps(payload), errors)
|
| 92 |
+
applied: List[Prediction] = []
|
| 93 |
+
corrected = copy.deepcopy(payload)
|
| 94 |
+
|
| 95 |
+
if not apply_fixes:
|
| 96 |
+
return {
|
| 97 |
+
"valid": False,
|
| 98 |
+
"rule_errors": errors,
|
| 99 |
+
"ml_predictions": preds,
|
| 100 |
+
"applied_fixes": [],
|
| 101 |
+
"corrected_json": corrected,
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
for pred in preds[:max_fixes]:
|
| 105 |
+
candidate = copy.deepcopy(corrected)
|
| 106 |
+
updated = _apply_fix(schema, candidate, pred)
|
| 107 |
+
if updated is None:
|
| 108 |
+
continue
|
| 109 |
+
now_valid, _ = validate_with_jsonschema(schema, candidate)
|
| 110 |
+
if now_valid:
|
| 111 |
+
corrected = candidate
|
| 112 |
+
applied.append(pred)
|
| 113 |
+
break
|
| 114 |
+
else:
|
| 115 |
+
# keep only if it reduces number of errors by any amount
|
| 116 |
+
prev_count = len(errors)
|
| 117 |
+
_, new_errs = validate_with_jsonschema(schema, candidate)
|
| 118 |
+
if len(new_errs) <= prev_count:
|
| 119 |
+
corrected = candidate
|
| 120 |
+
applied.append(pred)
|
| 121 |
+
errors = new_errs
|
| 122 |
+
if len(applied) >= max_fixes:
|
| 123 |
+
break
|
| 124 |
+
|
| 125 |
+
final_valid, final_errors = validate_with_jsonschema(schema, corrected)
|
| 126 |
+
return {
|
| 127 |
+
"valid": final_valid,
|
| 128 |
+
"rule_errors": final_errors if not final_valid else [],
|
| 129 |
+
"ml_predictions": preds,
|
| 130 |
+
"applied_fixes": applied,
|
| 131 |
+
"corrected_json": corrected,
|
| 132 |
+
}
|