#!/usr/bin/env python3
"""
Reference GRPO training script for agentic coding RL.
Uses execution-verified pass_rate as the reward signal.

Usage:
    python train_grpo.py \
        --model ./nexus-coder-sft \
        --output_dir ./nexus-coder-rl
"""

import argparse
import subprocess
import tempfile

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOTrainer, GRPOConfig


# ---------------------------------------------------------------------------
# Execution reward function (simplified; adapt to your sandbox)
# ---------------------------------------------------------------------------
def execution_reward_fn(completions: list, **kwargs) -> list:
    """
    Reward function for GRPO. Expects completions that contain bash
    commands or patches. In a real setup, replay commands in a Docker
    sandbox and return pass_rate.
    """
    rewards = []
    for completion in completions:
        try:
            # Execute the last ```bash ... ``` block, if any.
            if "```bash" in completion:
                cmd = completion.split("```bash")[-1].split("```")[0].strip()
                result = subprocess.run(
                    cmd,
                    shell=True,
                    capture_output=True,
                    timeout=30,
                    cwd=tempfile.gettempdir(),
                )
                reward = 1.0 if result.returncode == 0 else 0.0
            else:
                reward = 0.0
        except Exception:
            # Timeouts and crashed commands count as failures.
            reward = 0.0
        rewards.append(reward)
    return rewards
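
# ---------------------------------------------------------------------------
# Sandboxed variant (illustrative sketch). The function above executes
# commands directly on the host, which is only acceptable for trusted data.
# Below is a minimal sketch of the Docker-based replay the docstring alludes
# to; the image name ("nexus-sandbox") and the container flags are
# assumptions, not part of the reference setup.
# ---------------------------------------------------------------------------
def docker_reward_fn(completions: list, **kwargs) -> list:
    """Hypothetical sandboxed reward: run each bash block in a throwaway
    container and score on its exit code."""
    rewards = []
    for completion in completions:
        reward = 0.0
        if "```bash" in completion:
            cmd = completion.split("```bash")[-1].split("```")[0].strip()
            try:
                result = subprocess.run(
                    # --rm discards the container; --network=none cuts egress.
                    ["docker", "run", "--rm", "--network=none",
                     "nexus-sandbox",  # hypothetical image with repo + tests baked in
                     "bash", "-lc", cmd],
                    capture_output=True,
                    timeout=120,
                )
                reward = 1.0 if result.returncode == 0 else 0.0
            except Exception:
                reward = 0.0
        rewards.append(reward)
    return rewards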

# ---------------------------------------------------------------------------
# Dataset prep
# ---------------------------------------------------------------------------
def load_rl_dataset():
    """Load Nemotron RL SWE pivot dataset and normalize prompts."""
    ds = load_dataset("nvidia/Nemotron-RL-Agentic-SWE-Pivot-v1", split="train")

    def normalize(example):
        params = example.get("responses_create_params", {})
        inp = params.get("input", [])
        if len(inp) > 0 and isinstance(inp[0], dict):
            system = inp[0].get("content", "")
            ref = example.get("ref_message", {})
            reasoning = ref.get("reasoning_content", "") if isinstance(ref, dict) else ""
            return {
                "prompt": system,
                # Reference reasoning trace; GRPO generates its own
                # completions during training.
                "completion": reasoning,
            }
        return {"prompt": "", "completion": ""}

    ds = ds.map(normalize, remove_columns=ds.column_names)
    ds = ds.filter(lambda x: len(x["prompt"]) > 50)  # drop empty/degenerate prompts
    return ds


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to SFT checkpoint")
    parser.add_argument("--output_dir", default="./nexus-coder-rl")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--lr", type=float, default=1e-6)
    parser.add_argument("--max_prompt_length", type=int, default=4096)
    parser.add_argument("--max_completion_length", type=int, default=12288)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--hub_model_id", default=None)
    args = parser.parse_args()

    print("[1/4] Loading SFT model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype="bfloat16",
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("[2/4] Loading RL dataset...")
    dataset = load_rl_dataset()
    print(f"    RL dataset size: {len(dataset)} examples")

    print("[3/4] Configuring GRPO trainer...")
    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.lr,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,
        # Each prompt is sampled num_generations times; TRL checks that the
        # effective batch size is divisible by this value.
        num_generations=args.num_generations,
        temperature=0.7,
        logging_strategy="steps",
        logging_steps=5,
        logging_first_step=True,
        bf16=True,
        gradient_checkpointing=True,
        disable_tqdm=True,
        push_to_hub=args.hub_model_id is not None,
        hub_model_id=args.hub_model_id,
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[execution_reward_fn],
        args=grpo_config,
        train_dataset=dataset,
        processing_class=tokenizer,
    )

    print("[4/4] Starting GRPO training...")
    trainer.train()
    trainer.save_model(args.output_dir)
    print(f"Done. Model saved to {args.output_dir}")


if __name__ == "__main__":
    main()
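
# ---------------------------------------------------------------------------
# Suggested sanity check (not part of the training entrypoint): exercise the
# reward function from a REPL before committing GPUs to a long run. On a
# POSIX system:
#
#     >>> from train_grpo import execution_reward_fn
#     >>> execution_reward_fn(["```bash\ntrue\n```", "no code block here"])
#     [1.0, 0.0]
#
# For multi-GPU runs, the script can be launched with `accelerate launch
# train_grpo.py ...`, since GRPOTrainer builds on the transformers Trainer.
# ---------------------------------------------------------------------------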