import json
import re

from datasets import load_dataset

import verifiers as vf
|
def load_environment(
    num_train_examples=7000,
    num_eval_examples=1000,
    **kwargs,
):
    """
    Environment for verifying complex JSON output from models.

    The task requires models to:
    1. Parse multi-question prompts
    2. Generate valid JSON responses
    3. Match the expected structure with correct keys and values

    Rewards (no penalties, only positive rewards):
    - Formatting (valid JSON dict): 0.33 if pass, 0 if fail
    - All keys match: 0.33 if pass, 0 if fail
    - Answer values match: 0.33 if pass, 0 if fail
    Total max reward: ~1.0 (3 x 0.33 = 0.99)

    Args:
        num_train_examples: Desired size of the training split. Clamped to
            the actual dataset size so a short dataset cannot raise.
        num_eval_examples: Desired size of the eval split, taken from the
            rows immediately after the training split (also clamped).
        **kwargs: Unused; accepted for loader-interface compatibility.

    Returns:
        A configured ``vf.SingleTurnEnv``.
    """
    dataset = load_dataset("Delta-Vector/Tauri-Complex-JSON-Formatting", split="train")

    def format_example(example):
        # Map the raw dataset columns onto the (question, info) schema the
        # environment expects; verification_info is kept as its raw JSON string
        # and decoded lazily inside the reward functions.
        return {
            "question": example["prompt"],
            "info": {"verification_info": example["verification_info"]},
        }

    dataset = dataset.map(format_example, remove_columns=dataset.column_names)

    # Clamp split sizes so a smaller-than-requested dataset does not make
    # `select` raise IndexError. Sizes are unchanged in the normal case where
    # the dataset is large enough.
    n_train = min(num_train_examples, len(dataset))
    n_eval = min(num_eval_examples, len(dataset) - n_train)
    train_dataset = dataset.select(range(n_train))
    eval_dataset = dataset.select(range(n_train, n_train + n_eval))

    # Compiled once here instead of being re-imported and re-scanned on every
    # completion (the original did `import re` + `re.findall` per call).
    code_block_pattern = re.compile(r"```(?:json)?\s*\n(.*?)\n```", re.DOTALL)

    def extract_json_from_completion(completion):
        """Extract JSON text from a completion, unwrapping ``` code fences."""
        if not completion:
            return ""

        # Chat-style completions arrive as a list of messages; score the last.
        if isinstance(completion, list) and len(completion) > 0:
            content = completion[-1].get("content", "")
        else:
            content = str(completion)

        matches = code_block_pattern.findall(content)
        if matches:
            # Prefer the last fenced block: models often restate the final
            # answer at the end of the response.
            return matches[-1].strip()

        return content.strip()

    parser = vf.Parser(extract_fn=extract_json_from_completion)

    # Sentinel distinguishing "could not parse" from a legitimately decoded
    # JSON ``null`` (which json.loads returns as None).
    _PARSE_FAILED = object()

    def _parse_response(completion):
        """Parse the completion's extracted answer as JSON.

        Shared by all three reward functions so their parsing and error
        handling stay consistent. Returns the decoded value, or the
        ``_PARSE_FAILED`` sentinel when the answer is empty or invalid JSON.
        """
        try:
            text = (parser.parse_answer(completion) or "").strip()
            if not text:
                return _PARSE_FAILED
            return json.loads(text)
        except (json.JSONDecodeError, ValueError, TypeError):
            return _PARSE_FAILED

    def _load_ground_truth(info):
        """Decode ``info["verification_info"]`` and return its ground truth.

        Returns ``_PARSE_FAILED`` on any malformed row. KeyError is included
        deliberately: the original code let a missing key crash scoring
        instead of yielding a 0 reward.
        """
        try:
            verification_info = json.loads(info["verification_info"])
            return verification_info["ground_truth"]
        except (json.JSONDecodeError, ValueError, KeyError, AttributeError, TypeError):
            return _PARSE_FAILED

    def format_reward(completion, **kwargs) -> float:
        """
        Reward for valid JSON formatting.
        Returns 0.33 for a valid JSON dict, 0 otherwise.
        """
        parsed = _parse_response(completion)
        if parsed is _PARSE_FAILED or not isinstance(parsed, dict):
            return 0.0
        return 0.33

    def keys_match_reward(completion, info, **kwargs) -> float:
        """
        Reward for matching keys in the JSON structure.
        Returns 0.33 if the full (nested) key sets match, 0 otherwise.
        """
        parsed_response = _parse_response(completion)
        if parsed_response is _PARSE_FAILED or not isinstance(parsed_response, dict):
            return 0.0

        ground_truth = _load_ground_truth(info)
        if ground_truth is _PARSE_FAILED:
            return 0.0

        def get_all_keys(d, prefix=""):
            # Flatten nested dict keys into dotted paths, e.g.
            # {"a": {"b": 1}} -> {"a", "a.b"}. Dicts nested inside lists are
            # not traversed — only direct dict nesting counts.
            keys = set()
            if isinstance(d, dict):
                for k, v in d.items():
                    full_key = f"{prefix}.{k}" if prefix else k
                    keys.add(full_key)
                    keys.update(get_all_keys(v, full_key))
            return keys

        return 0.33 if get_all_keys(ground_truth) == get_all_keys(parsed_response) else 0.0

    def values_match_reward(completion, info, **kwargs) -> float:
        """
        Reward for matching values in the JSON structure.
        Returns 0.33 if the response deep-equals the ground truth, 0 otherwise.
        """
        parsed_response = _parse_response(completion)
        if parsed_response is _PARSE_FAILED:
            return 0.0

        ground_truth = _load_ground_truth(info)
        if ground_truth is _PARSE_FAILED:
            return 0.0

        def deep_compare(a, b):
            # Strict by design: any type mismatch (e.g. 1 vs 1.0, or 1 vs
            # True) is a non-match, unlike plain `==`.
            if type(a) != type(b):
                return False
            if isinstance(a, dict):
                if set(a.keys()) != set(b.keys()):
                    return False
                return all(deep_compare(a[k], b[k]) for k in a.keys())
            elif isinstance(a, list):
                if len(a) != len(b):
                    return False
                return all(deep_compare(a[i], b[i]) for i in range(len(a)))
            else:
                return a == b

        return 0.33 if deep_compare(parsed_response, ground_truth) else 0.0

    rubric = vf.Rubric(
        parser=parser,
        funcs=[
            format_reward,
            keys_match_reward,
            values_match_reward,
        ],
        weights=[1.0, 1.0, 1.0],
    )

    vf_env = vf.SingleTurnEnv(
        dataset=train_dataset,
        eval_dataset=eval_dataset,
        parser=parser,
        rubric=rubric,
    )

    return vf_env
| |
|