thearnabsarkar commited on
Commit
46ae221
·
verified ·
1 Parent(s): a329232

Upload json_semval/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. json_semval/pipeline.py +132 -0
json_semval/pipeline.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ from typing import Any, Dict, List
6
+
7
+ from .fixes import cast_bool, cast_number, map_enum, parse_date_iso, rename_key
8
+ from .ml_model import SemanticReasoner
9
+ from .rules_engine import validate_with_jsonschema
10
+ from .schema_utils import collect_enums
11
+ from .types import Prediction, Report
12
+
13
+
14
+ def _apply_fix(schema: Dict[str, Any], payload: Any, pred: Prediction) -> Any | None:
15
+ path = pred.get("jsonpath", "$")
16
+ if not path.startswith("$"):
17
+ return None
18
+ # convert to tokens
19
+ tokens: List[str] = []
20
+ rest = path[1:]
21
+ i = 0
22
+ while i < len(rest):
23
+ ch = rest[i]
24
+ if ch == ".":
25
+ j = i + 1
26
+ name = []
27
+ while j < len(rest) and rest[j] not in ".[":
28
+ name.append(rest[j])
29
+ j += 1
30
+ if name:
31
+ tokens.append("." + "".join(name))
32
+ i = j
33
+ continue
34
+ if ch == "[":
35
+ j = rest.find("]", i)
36
+ tokens.append(rest[i : j + 1])
37
+ i = j + 1
38
+ continue
39
+ i += 1
40
+
41
+ action = pred.get("fix_action", "")
42
+ if action == "rename_key":
43
+ dst = pred.get("fix_value") or "_renamed"
44
+ try:
45
+ return rename_key(payload, tokens, dst)
46
+ except Exception:
47
+ return None
48
+ if action == "cast_number":
49
+ return cast_number(payload, tokens)
50
+ if action == "cast_bool":
51
+ return cast_bool(payload, tokens)
52
+ if action == "parse_date_iso":
53
+ return parse_date_iso(payload, tokens)
54
+ if action == "map_enum":
55
+ enums = collect_enums(schema)
56
+ allowed = enums.get(path.replace("$", ""), [])
57
+ return map_enum(payload, tokens, allowed)
58
+ # fill_default or unknown → skip
59
+ return None
60
+
61
+
62
+ def run_validation(
63
+ schema: Dict[str, Any],
64
+ payload: Any,
65
+ *,
66
+ apply_fixes: bool = True,
67
+ max_fixes: int = 5,
68
+ backend: str = "local",
69
+ ) -> Report:
70
+ is_valid, errors = validate_with_jsonschema(schema, payload)
71
+ if is_valid:
72
+ return {
73
+ "valid": True,
74
+ "rule_errors": [],
75
+ "ml_predictions": [],
76
+ "applied_fixes": [],
77
+ "corrected_json": payload,
78
+ }
79
+
80
+ # Honor explicit rules-only backend: do not invoke ML or apply fixes
81
+ if backend == "rules-only":
82
+ return {
83
+ "valid": False,
84
+ "rule_errors": errors,
85
+ "ml_predictions": [],
86
+ "applied_fixes": [],
87
+ "corrected_json": payload,
88
+ }
89
+
90
+ reasoner = SemanticReasoner(backend=backend)
91
+ preds = reasoner.predict(json.dumps(schema), json.dumps(payload), errors)
92
+ applied: List[Prediction] = []
93
+ corrected = copy.deepcopy(payload)
94
+
95
+ if not apply_fixes:
96
+ return {
97
+ "valid": False,
98
+ "rule_errors": errors,
99
+ "ml_predictions": preds,
100
+ "applied_fixes": [],
101
+ "corrected_json": corrected,
102
+ }
103
+
104
+ for pred in preds[:max_fixes]:
105
+ candidate = copy.deepcopy(corrected)
106
+ updated = _apply_fix(schema, candidate, pred)
107
+ if updated is None:
108
+ continue
109
+ now_valid, _ = validate_with_jsonschema(schema, candidate)
110
+ if now_valid:
111
+ corrected = candidate
112
+ applied.append(pred)
113
+ break
114
+ else:
115
+ # keep only if it reduces number of errors by any amount
116
+ prev_count = len(errors)
117
+ _, new_errs = validate_with_jsonschema(schema, candidate)
118
+ if len(new_errs) <= prev_count:
119
+ corrected = candidate
120
+ applied.append(pred)
121
+ errors = new_errs
122
+ if len(applied) >= max_fixes:
123
+ break
124
+
125
+ final_valid, final_errors = validate_with_jsonschema(schema, corrected)
126
+ return {
127
+ "valid": final_valid,
128
+ "rule_errors": final_errors if not final_valid else [],
129
+ "ml_predictions": preds,
130
+ "applied_fixes": applied,
131
+ "corrected_json": corrected,
132
+ }