Delta-Vector committed on
Commit e0b3d16 · verified · 1 Parent(s): 3accff4

Upload folder using huggingface_hub

__pycache__/complex_json_output.cpython-312.pyc ADDED
Binary file (11.7 kB).
complex_json_output.py ADDED
@@ -0,0 +1,337 @@
import json
import re

from datasets import load_dataset

import verifiers as vf


def load_environment(
    num_train_examples=7000,
    num_eval_examples=1000,
    **kwargs
):
    """
    Environment for verifying complex JSON output from models.

    The task requires models to:
    1. Parse multi-question prompts
    2. Generate valid JSON responses
    3. Match the expected structure with correct keys and values

    Reward structure (multiplicative, to prevent local minima):
    - If the JSON fails to parse: reward = 0
    - Otherwise:
      * key_accuracy = correct_keys / total_keys_in_response
      * value_accuracy = correct_values / total_values_in_response
      * final_reward = key_accuracy * value_accuracy

    This penalizes both missing keys/values AND adding extra incorrect ones.
    """
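
    # Illustrative worked example (hypothetical values, not from the dataset):
    # a ground truth of {"a": 1, "b": {"c": 2}} has key paths {a, b, b.c}. A
    # response of {"a": 1, "b": {"c": 3}, "d": 4} has key paths {a, b, b.c, d},
    # so key_accuracy = 3/4. Values are checked at all four response paths and
    # only "a" matches (the dicts at "b" differ, "b.c" is 3 vs 2, and "d" has
    # no ground-truth value), so value_accuracy = 1/4 and the reward is 0.1875.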

    # Load the dataset from the Hugging Face Hub
    dataset = load_dataset("Delta-Vector/Tauri-Complex-JSON-Formatting", split="train")

    # Map to the expected format; keep verification_info as a string to avoid schema issues
    def format_example(example):
        return {
            "question": example["prompt"],
            "info": {"verification_info": example["verification_info"]},  # keep as a dict holding a string
        }

    dataset = dataset.map(format_example, remove_columns=dataset.column_names)

    # Split into train and eval (the dataset must hold at least
    # num_train_examples + num_eval_examples rows)
    train_dataset = dataset.select(range(num_train_examples))
    eval_dataset = dataset.select(range(num_train_examples, num_train_examples + num_eval_examples))

    # Custom extract function to parse JSON from code blocks or raw text
    def extract_json_from_completion(completion):
        """Extract JSON from a completion, handling Markdown code blocks."""
        if not completion:
            return ""

        # Get the content of the last message
        if isinstance(completion, list) and len(completion) > 0:
            content = completion[-1].get("content", "")
        else:
            content = str(completion)

        # Try to extract from code blocks first (```json ... ``` or ``` ... ```)
        code_block_pattern = r"```(?:json)?\s*\n(.*?)\n```"
        matches = re.findall(code_block_pattern, content, re.DOTALL)
        if matches:
            return matches[-1].strip()  # return the last code block

        # Otherwise return the content as-is
        return content.strip()

    # Use a simple Parser with the custom extract function
    parser = vf.Parser(extract_fn=extract_json_from_completion)
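
    # Illustrative example (hypothetical message, not from the dataset): a chat
    # completion of [{"role": "assistant", "content": "Sure:\n```json\n{\"age\": 25}\n```"}]
    # is reduced by the extractor to the string '{"age": 25}' before the reward
    # functions below attempt json.loads on it.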

    def multiplicative_reward(completion, info, **kwargs) -> float:
        """
        Multiplicative reward: key_accuracy * value_accuracy.

        Returns 0 if JSON fails to parse.
        Otherwise:
        - key_accuracy = correct_keys / total_keys_in_response
        - value_accuracy = correct_values / total_values_in_response
        - final_reward = key_accuracy * value_accuracy

        This penalizes both missing correct items AND adding extra incorrect ones.
        """
        try:
            response = parser.parse_answer(completion) or ""
            response = response.strip()

            # Check: valid JSON format
            if not response:
                return 0.0

            try:
                parsed_response = json.loads(response)
            except (json.JSONDecodeError, ValueError):
                return 0.0

            # Must be a dict
            if not isinstance(parsed_response, dict):
                return 0.0

            # Parse the ground truth from info
            verification_info = json.loads(info["verification_info"])
            ground_truth = verification_info["ground_truth"]

            # Get all keys recursively with their full dotted paths
            def get_all_keys(d, prefix=""):
                keys = set()
                if isinstance(d, dict):
                    for k, v in d.items():
                        full_key = f"{prefix}.{k}" if prefix else k
                        keys.add(full_key)
                        keys.update(get_all_keys(v, full_key))
                return keys

            # Get all leaf values recursively (currently unused: the value
            # checks below go through get_value_at_path instead)
            def get_all_values(d):
                values = []
                if isinstance(d, dict):
                    for v in d.values():
                        if isinstance(v, dict):
                            values.extend(get_all_values(v))
                        elif isinstance(v, list):
                            for item in v:
                                values.extend(get_all_values({"_": item}))
                        else:
                            values.append(v)
                return values

            ground_truth_keys = get_all_keys(ground_truth)
            response_keys = get_all_keys(parsed_response)

            # Calculate key accuracy
            if len(response_keys) == 0:
                key_accuracy = 0.0
            else:
                correct_keys = len(ground_truth_keys & response_keys)  # intersection
                key_accuracy = correct_keys / len(response_keys)

            # Calculate value accuracy by checking the value at each key path
            def get_value_at_path(d, path):
                """Get the value at a dotted key path like 'a.b.c'."""
                keys = path.split('.')
                current = d
                try:
                    for key in keys:
                        current = current[key]
                    return current
                except (KeyError, TypeError):
                    return None

            # Helper function to compare values with numeric type tolerance
            def values_equal(a, b):
                """Compare values with numeric type tolerance (25 == 25.0)."""
                # Python handles int/float equality correctly
                if isinstance(a, (int, float)) and isinstance(b, (int, float)):
                    return a == b
                # For everything else, use strict equality
                return a == b

            # Check values at every response key path, so paths missing from
            # the ground truth count against value accuracy
            total_values_checked = len(response_keys)

            if total_values_checked == 0:
                value_accuracy = 0.0
            else:
                correct_values = 0
                for key_path in response_keys:
                    response_val = get_value_at_path(parsed_response, key_path)
                    ground_truth_val = get_value_at_path(ground_truth, key_path)

                    # If the key exists in the ground truth and the values match
                    # (note: a legitimate null in the ground truth is treated as
                    # a missing key by this check)
                    if ground_truth_val is not None and values_equal(response_val, ground_truth_val):
                        correct_values += 1

                value_accuracy = correct_values / total_values_checked

            # Multiply together
            final_reward = key_accuracy * value_accuracy
            return final_reward

        # ValueError also covers malformed verification_info JSON
        except (AttributeError, TypeError, KeyError, ValueError):
            return 0.0

    def format_reward(completion, **kwargs) -> float:
        """
        Reward for valid JSON formatting.
        Returns 0.33 for a valid JSON dict, 0 for invalid.
        """
        try:
            response = parser.parse_answer(completion) or ""
            response = response.strip()

            # Check that the response is not empty
            if not response:
                return 0.0

            # Try to parse as JSON
            parsed = json.loads(response)

            # Must be a dict (since the ground truth is always a dict)
            if not isinstance(parsed, dict):
                return 0.0

            return 0.33

        except (json.JSONDecodeError, ValueError, TypeError):
            return 0.0

    def keys_match_reward(completion, info, **kwargs) -> float:
        """
        Metric: key accuracy (correct_keys / total_keys_in_response).
        Returns the same key_accuracy used in multiplicative_reward.
        """
        try:
            response = parser.parse_answer(completion) or ""
            response = response.strip()

            if not response:
                return 0.0

            parsed_response = json.loads(response)

            if not isinstance(parsed_response, dict):
                return 0.0

            # Parse the ground truth from info
            verification_info = json.loads(info["verification_info"])
            ground_truth = verification_info["ground_truth"]

            # Get all keys from the ground truth (recursively)
            def get_all_keys(d, prefix=""):
                keys = set()
                if isinstance(d, dict):
                    for k, v in d.items():
                        full_key = f"{prefix}.{k}" if prefix else k
                        keys.add(full_key)
                        keys.update(get_all_keys(v, full_key))
                return keys

            ground_truth_keys = get_all_keys(ground_truth)
            response_keys = get_all_keys(parsed_response)

            if len(response_keys) == 0:
                return 0.0

            correct_keys = len(ground_truth_keys & response_keys)
            return correct_keys / len(response_keys)

        except (json.JSONDecodeError, ValueError, AttributeError, TypeError):
            return 0.0

    def values_match_reward(completion, info, **kwargs) -> float:
        """
        Metric: value accuracy (correct_values / total_values_in_response).
        Returns the same value_accuracy used in multiplicative_reward.
        """
        try:
            response = parser.parse_answer(completion) or ""
            response = response.strip()

            if not response:
                return 0.0

            parsed_response = json.loads(response)

            if not isinstance(parsed_response, dict):
                return 0.0

            # Parse the ground truth from info
            verification_info = json.loads(info["verification_info"])
            ground_truth = verification_info["ground_truth"]

            # Helper function to compare values with numeric type tolerance
            def values_equal(a, b):
                if isinstance(a, (int, float)) and isinstance(b, (int, float)):
                    return a == b
                return a == b

            # Get all keys recursively
            def get_all_keys(d, prefix=""):
                keys = set()
                if isinstance(d, dict):
                    for k, v in d.items():
                        full_key = f"{prefix}.{k}" if prefix else k
                        keys.add(full_key)
                        keys.update(get_all_keys(v, full_key))
                return keys

            def get_value_at_path(d, path):
                keys = path.split('.')
                current = d
                try:
                    for key in keys:
                        current = current[key]
                    return current
                except (KeyError, TypeError):
                    return None

            response_keys = get_all_keys(parsed_response)

            if len(response_keys) == 0:
                return 0.0

            correct_values = 0
            for key_path in response_keys:
                response_val = get_value_at_path(parsed_response, key_path)
                ground_truth_val = get_value_at_path(ground_truth, key_path)

                if ground_truth_val is not None and values_equal(response_val, ground_truth_val):
                    correct_values += 1

            return correct_values / len(response_keys)

        except (json.JSONDecodeError, ValueError, AttributeError, TypeError):
            return 0.0

    # Create the rubric with the multiplicative reward
    # Keep the individual functions for debugging/metrics but use the multiplicative one for training
    rubric = vf.Rubric(
        parser=parser,
        funcs=[
            multiplicative_reward,  # main reward: key_acc * value_acc
            format_reward,          # metric only (weight 0)
            keys_match_reward,      # metric only (weight 0)
            values_match_reward,    # metric only (weight 0)
        ],
        weights=[1.0, 0.0, 0.0, 0.0],  # only multiplicative_reward counts
    )

    # Return a SingleTurnEnv since this is a one-shot task
    # No system prompt - let the dataset prompt speak for itself
    vf_env = vf.SingleTurnEnv(
        dataset=train_dataset,
        eval_dataset=eval_dataset,
        parser=parser,
        rubric=rubric,
    )

    return vf_env
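
A minimal smoke-test sketch for the environment (mirroring the loader call in train_complex_json_output.py below; the small example counts are illustrative, and the underlying dataset must contain at least num_train_examples + num_eval_examples rows):

import verifiers as vf

env = vf.load_environment(
    env_id="complex-json-output",
    num_train_examples=100,
    num_eval_examples=10,
)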
pyproject.toml ADDED
@@ -0,0 +1,14 @@
[project]
name = "complex-json-output"
description = "Environment for verifying complex JSON output formatting and correctness"
tags = ["json", "instruction-following", "verifiable-reward", "train", "eval"]
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
    "verifiers>=0.1.5.post0",
    "datasets",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
train_complex_json_output.py ADDED
@@ -0,0 +1,101 @@
import verifiers as vf

"""
# install
vf-install complex-json-output (-p /path/to/environments)

# quick eval
vf-eval complex-json-output (-m model_name in endpoints.py)

inference:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 vf-vllm --model Qwen/Qwen2.5-1.5B-Instruct \
    --data-parallel-size 6 --enforce-eager --disable-log-requests

training:
CUDA_VISIBLE_DEVICES=6,7 accelerate launch --num-processes 2 \
    --config-file configs/zero3.yaml examples/grpo/train_complex_json_output.py
"""

# Hyperparameters tuned below
HPARAMS = [
    "per_device_train_batch_size",
    "num_generations",
    "gradient_accumulation_steps",
    "max_tokens",
    "max_seq_len",
    "max_prompt_length",
    "max_completion_length",
    "temperature",
    "learning_rate",
    "max_steps",
    "warmup_steps",
    "eval_steps",
    "save_steps",
    "beta",
    "loss_type",
]

# Load the environment
vf_env = vf.load_environment(
    env_id="complex-json-output",
    num_train_examples=8000,  # use a subset for faster training
    num_eval_examples=50,
)

# Model configuration
model_name = "/raid/workspace/Mango/verifiers/MS3.2-0.35-Beta"
run_name = "complex-json-grpo_" + model_name.split("/")[-1].lower()

# Load the model and tokenizer
model, tokenizer = vf.get_model_and_tokenizer(model_name)

# Training arguments
training_args = vf.grpo_defaults(run_name=run_name)

# Batch configuration
training_args.per_device_train_batch_size = 2
training_args.num_generations = 16
training_args.gradient_accumulation_steps = 2

# Generation parameters
training_args.max_tokens = 2048  # JSON outputs can be long
training_args.max_seq_len = 16000
training_args.max_prompt_length = 8192  # allow long prompts (questions can be lengthy)
training_args.max_completion_length = 4096  # allow long completions
training_args.temperature = 1.0  # full diversity for exploration

# Training schedule
training_args.learning_rate = 5e-6
training_args.max_steps = 1000
training_args.warmup_steps = 15

# Evaluation (disabled; eval_steps is kept for when it is re-enabled)
training_args.eval_strategy = "none"
training_args.eval_steps = 50
training_args.per_device_eval_batch_size = 8

# Checkpointing
training_args.save_strategy = "steps"
training_args.save_steps = 100

# GRPO parameters
training_args.beta = 0.001  # conservative KL penalty
training_args.loss_type = "dr_grpo"  # recommended: no length bias

# Logging
training_args.logging_steps = 1
training_args.log_completions = True
training_args.num_completions_to_print = 3
training_args.report_to = "wandb"  # log metrics to wandb

# Create the trainer
trainer = vf.GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    env=vf_env,
    args=training_args,
    peft_config=vf.lora_defaults(r=8, alpha=16),  # use LoRA for efficiency
)

# Train
trainer.train()
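
After training, the final weights can be persisted with the standard Trainer API (a minimal sketch, assuming verifiers' GRPOTrainer follows the transformers Trainer interface; the output path is hypothetical):

# Save the final LoRA adapter/model and tokenizer alongside the step checkpoints
trainer.save_model("outputs/" + run_name)
tokenizer.save_pretrained("outputs/" + run_name)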