import torch
import torch.nn.functional as F
from tqdm import tqdm


class DiffuMCTS:
    """
    Diffusion-based Monte Carlo Tree Search (simulation-based).

    Pipeline: encode buggy code into a latent -> flow-matching generation
    with one branching point (expansion) -> decode all candidates ->
    sandbox verification (Pass@K).
    """

    def __init__(self, ae, flow, tokenizer, sandbox, device, config):
        """
        Args:
            ae: autoencoder exposing .encode(input_ids, attention_mask)
                and .decode(z) -> logits.
            flow: velocity model, called as flow(z, t, condition=z_cond).
            tokenizer: HF-style tokenizer (callable, plus .batch_decode).
            sandbox: executor with .run(code, test_code, entry_point)
                -> (passed: bool, message).
            device: torch device all tensors are placed on.
            config: NOTE(review) — currently unused; the search
                hyperparameters below are hard-coded. TODO: wire to config.
        """
        self.ae = ae
        self.flow = flow
        self.tokenizer = tokenizer
        self.sandbox = sandbox
        self.device = device

        # Search configuration (hard-coded defaults).
        self.num_branches = 8   # number of parallel branches (K)
        self.split_t = 0.5      # branch time; t runs 0 (buggy) -> 1 (fixed)
        self.noise_scale = 0.1  # std of the Gaussian perturbation at the split
        self.steps = 10         # number of Euler steps for the flow ODE

    @torch.no_grad()
    def solve(self, buggy_code, test_code, entry_point):
        """
        Public entry point: attempt to repair `buggy_code`.

        Returns:
            fixed_code (str): best candidate (first passing one, else the
                first decoded candidate).
            success (bool): whether any candidate passed the tests.
        """
        # 1. Encode the buggy code (source distribution) -> latent z.
        tokens = self.tokenizer(
            buggy_code,
            max_length=2048,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        z_buggy = self.ae.encode(tokens['input_ids'], tokens['attention_mask'])

        # 2. One branched inference pass (parallel Pass@K search).
        best_code, success = self._parallel_rollout(z_buggy, test_code, entry_point)
        return best_code, success

    def _parallel_rollout(self, z_start, test_code, entry_point):
        """
        Run the branched Euler rollout and verify every candidate.

        Time convention (from training): z_t = (1 - t) * z_bad + t * z_good,
        so t = 0 is the buggy latent and t = 1 the repaired one. We integrate
        forward (t: 0 -> 1) with step size dt = 1 / steps.
        """
        B = z_start.shape[0]
        K = self.num_branches

        z_curr = z_start.clone()
        z_cond = z_start.clone()  # condition is always the buggy latent
        dt = 1.0 / self.steps

        # Step index at which we branch. Since t increases 0 -> 1, the split
        # at t = split_t happens after split_t * steps Euler steps.
        # (Bug fix: the original used int((1 - split_t) * steps), a leftover
        # from the reversed-time convention; the two agree only at
        # split_t = 0.5, the current default.)
        split_step_idx = int(self.split_t * self.steps)

        # --- Stage 1: deterministic flow (t: 0 -> split_t) ---
        # Move a few steps away from the buggy latent so the semantics
        # stabilize before we branch.
        for i in range(split_step_idx):
            t_tensor = torch.ones(B, device=self.device) * (i / self.steps)
            v = self.flow(z_curr, t_tensor, condition=z_cond)
            z_curr = z_curr + v * dt  # forward Euler: z_{t+dt} = z_t + v * dt

        # --- Stage 2: expansion (branching) ---
        # Replicate the current latent K times: [B, L, D] -> [B*K, L, D].
        z_branches = z_curr.repeat(K, 1, 1)
        z_cond_branches = z_cond.repeat(K, 1, 1)

        # Inject Gaussian noise for exploration, then re-project onto the
        # unit sphere to keep the manifold constraint.
        noise = torch.randn_like(z_branches) * self.noise_scale
        z_branches = F.normalize(z_branches + noise, p=2, dim=-1)

        # --- Stage 3: parallel rollout of all branches (t: split_t -> 1) ---
        for step_idx in range(split_step_idx, self.steps):
            t_tensor = torch.ones(z_branches.shape[0], device=self.device) * (step_idx / self.steps)
            v = self.flow(z_branches, t_tensor, condition=z_cond_branches)
            z_branches = z_branches + v * dt

        # --- Stage 4: decoding & verification ---
        logits = self.ae.decode(z_branches)  # [B*K, L, Vocab]
        pred_ids = torch.argmax(logits, dim=-1)
        candidate_codes = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # Pass@K: succeed as soon as any single candidate passes the tests.
        for code in candidate_codes:
            is_pass, _msg = self.sandbox.run(code, test_code, entry_point)
            if is_pass:
                return code, True

        # All failed: fall back to the first candidate (a smarter heuristic
        # could rank candidates by closeness to passing).
        return candidate_codes[0], False