| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
class DiffuMCTS:
    """Diffusion-based Monte Carlo Tree Search over code latents (simulation-based).

    Pipeline: encode the buggy program into a latent with the autoencoder,
    integrate a flow-matching ODE conditioned on that latent, fork the
    trajectory into ``num_branches`` noisy branches partway through,
    finish integrating every branch in parallel, decode each branch back
    to source text, and accept the first candidate that passes the
    sandboxed tests.
    """

    def __init__(self, ae, flow, tokenizer, sandbox, device, config):
        """Store collaborators and resolve search hyper-parameters.

        Args:
            ae: autoencoder exposing ``encode(input_ids, attention_mask)``
                and ``decode(z) -> logits``.
            flow: flow-matching model called as ``flow(z, t, condition=...)``
                returning a velocity field with ``z``'s shape.
            tokenizer: HuggingFace-style tokenizer (callable, plus
                ``batch_decode``).
            sandbox: test runner exposing
                ``run(code, test_code, entry_point) -> (passed, message)``.
            device: torch device for all tensors created here.
            config: optional dict or attribute-style object overriding the
                search hyper-parameters below; ``None`` keeps the defaults.
        """
        self.ae = ae
        self.flow = flow
        self.tokenizer = tokenizer
        self.sandbox = sandbox
        self.device = device
        self.config = config

        # ``config`` used to be accepted but ignored; honor it now, keeping
        # the original hard-coded values as backward-compatible defaults.
        self.num_branches = self._cfg_get(config, "num_branches", 8)
        self.split_t = self._cfg_get(config, "split_t", 0.5)
        self.noise_scale = self._cfg_get(config, "noise_scale", 0.1)
        self.steps = self._cfg_get(config, "steps", 10)

    @staticmethod
    def _cfg_get(config, key, default):
        """Read ``key`` from a dict-like or attribute-style config, else ``default``."""
        if config is None:
            return default
        if isinstance(config, dict):
            return config.get(key, default)
        return getattr(config, key, default)

    @torch.no_grad()
    def solve(self, buggy_code, test_code, entry_point):
        """Public entry point: attempt to repair ``buggy_code``.

        Args:
            buggy_code: source text of the broken program.
            test_code: test suite text handed to the sandbox.
            entry_point: function name the sandbox invokes.

        Returns:
            tuple ``(fixed_code, success)`` — the best candidate program
            and whether it passed the sandbox tests.
        """
        tokens = self.tokenizer(
            buggy_code,
            max_length=2048,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        # Latent of the buggy program; it also serves as the flow condition.
        z_buggy = self.ae.encode(tokens['input_ids'], tokens['attention_mask'])

        return self._parallel_rollout(z_buggy, test_code, entry_point)

    def _parallel_rollout(self, z_start, test_code, entry_point):
        """Shared-prefix flow integration, branch, finish, decode, and test.

        Args:
            z_start: latent of shape ``(B, L, D)`` from the autoencoder.
            test_code: test suite text for the sandbox.
            entry_point: function name the sandbox invokes.

        Returns:
            ``(code, True)`` for the first branch that passes the tests,
            otherwise ``(first_candidate, False)`` as a best effort.
        """
        B = z_start.shape[0]
        K = self.num_branches

        z_curr = z_start.clone()
        z_cond = z_start.clone()  # condition stays pinned to the buggy latent

        dt = 1.0 / self.steps

        # Step index at which the single trajectory forks into K branches.
        # NOTE(review): with this formula the fork happens at t = 1 - split_t
        # on the 0->1 integration axis (split_t reads as "time remaining") —
        # confirm that is the intended convention.
        split_step_idx = int((1.0 - self.split_t) * self.steps)

        # ---- Phase 1: shared prefix, one trajectory integrated from t = 0.
        # (A leftover t = 1 - i/steps computation that was immediately
        # overwritten has been removed; only the forward 0->1 time is used.)
        for i in range(split_step_idx):
            t_tensor = torch.full((B,), i / self.steps, device=self.device)
            v = self.flow(z_curr, t_tensor, condition=z_cond)
            z_curr = z_curr + v * dt  # explicit Euler step

        # ---- Phase 2: fork into K branches with scaled Gaussian noise.
        z_branches = z_curr.repeat(K, 1, 1)
        z_cond_branches = z_cond.repeat(K, 1, 1)
        z_branches = z_branches + torch.randn_like(z_branches) * self.noise_scale

        # Re-project onto the unit sphere — presumably the AE latent space is
        # L2-normalized per position; verify against the autoencoder training.
        z_branches = F.normalize(z_branches, p=2, dim=-1)

        # ---- Phase 3: finish integrating all K*B branches in parallel.
        for i in range(self.steps - split_step_idx):
            step_idx = split_step_idx + i
            t_tensor = torch.full(
                (z_branches.shape[0],), step_idx / self.steps, device=self.device
            )
            v = self.flow(z_branches, t_tensor, condition=z_cond_branches)
            z_branches = z_branches + v * dt

        # ---- Decode every branch to token ids, then to source text.
        logits = self.ae.decode(z_branches)
        pred_ids = torch.argmax(logits, dim=-1)
        candidate_codes = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # Accept the first candidate that passes the sandboxed tests.
        for code in candidate_codes:
            passed, _msg = self.sandbox.run(code, test_code, entry_point)
            if passed:
                return code, True

        # No branch passed; return the first decode as a best effort.
        return candidate_codes[0], False