# Diff-Refine / src/search.py
# (uploaded by 2ira via upload-large-folder tool; commit 77d636f, verified)
import torch
import torch.nn.functional as F
from tqdm import tqdm
class DiffuMCTS:
    """
    Diffusion-based Monte Carlo Tree Search (simulation-based).

    Pipeline: encode the buggy program into a latent, run a flow-matching
    ODE with parallel branching (exploration noise at a split point),
    decode every branch into candidate code, and verify each candidate in
    a sandbox (Pass@K style: the first passing candidate wins).
    """

    def __init__(self, ae, flow, tokenizer, sandbox, device, config):
        """
        Args:
            ae: autoencoder exposing ``encode(input_ids, attention_mask) -> z``
                and ``decode(z) -> logits``.
            flow: velocity model, called as ``flow(z, t, condition=...)``.
            tokenizer: HF-style tokenizer (``__call__`` and ``batch_decode``).
            sandbox: executor exposing ``run(code, test_code, entry_point)``
                returning ``(passed, message)``.
            device: torch device used for all tensors.
            config: search configuration.  NOTE(review): currently unused —
                the hyper-parameters below are hard-coded; consider reading
                them from ``config`` instead.
        """
        self.ae = ae
        self.flow = flow
        self.tokenizer = tokenizer
        self.sandbox = sandbox
        self.device = device
        # Search configuration.
        self.num_branches = 8   # number of parallel branches (K)
        self.split_t = 0.5      # branching point (0 = target/good, 1 = source/buggy)
        self.noise_scale = 0.1  # std of the Gaussian perturbation applied at branching
        self.steps = 10         # number of Euler steps for the flow ODE

    @torch.no_grad()
    def solve(self, buggy_code, test_code, entry_point):
        """
        Try to repair ``buggy_code`` so that it passes ``test_code``.

        Returns:
            fixed_code (str): the first candidate that passed the sandbox
                tests, or the first generated candidate if none passed.
            success (bool): True iff some candidate passed the tests.
        """
        # 1. Encode the buggy code (source distribution) into latent z.
        tokens = self.tokenizer(
            buggy_code,
            max_length=2048,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).to(self.device)
        z_buggy = self.ae.encode(tokens['input_ids'], tokens['attention_mask'])
        # 2. Single flow-matching inference pass with parallel branching.
        best_code, success = self._parallel_rollout(z_buggy, test_code, entry_point)
        return best_code, success

    def _parallel_rollout(self, z_start, test_code, entry_point):
        """
        Shared-prefix Euler integration, then K noisy branches rolled out
        in parallel, decoded, and verified in the sandbox.

        Time convention (from training): ``z_t = (1 - t) * z_bad + t * z_good``,
        so t = 0 is the buggy latent and t = 1 the repaired one.
        Forward Euler: ``z_{t+dt} = z_t + v(z_t, t) * dt``.

        Returns:
            (code, True) for the first candidate that passes the sandbox
            tests, otherwise (first candidate, False).
        """
        B, L, D = z_start.shape
        K = self.num_branches
        z_curr = z_start.clone()
        z_cond = z_start.clone()  # the condition is always the buggy latent
        dt = 1.0 / self.steps
        # Step index at which branching happens.  NOTE(review): with the
        # "split_t: 0 = target, 1 = source" convention this is
        # (1 - split_t) * steps forward steps; at split_t = 0.5 both
        # readings of split_t coincide, so behavior is kept as-is.
        split_step_idx = int((1.0 - self.split_t) * self.steps)

        # --- Stage 1: deterministic shared prefix (t: 0 -> split point) ---
        # Walk a few steps first so the semantics stabilize before branching.
        for i in range(split_step_idx):
            current_t_val = i / self.steps
            t_tensor = torch.ones(B, device=self.device) * current_t_val
            v = self.flow(z_curr, t_tensor, condition=z_cond)
            z_curr = z_curr + v * dt

        # --- Stage 2: expansion (branching) ---
        # Replicate the shared state K times: [B, L, D] -> [B*K, L, D].
        z_branches = z_curr.repeat(K, 1, 1)
        z_cond_branches = z_cond.repeat(K, 1, 1)
        # Gaussian perturbation for exploration: z' = z + noise.
        noise = torch.randn_like(z_branches) * self.noise_scale
        z_branches = z_branches + noise
        # Re-project onto the unit sphere to keep the manifold constraint.
        z_branches = F.normalize(z_branches, p=2, dim=-1)

        # --- Stage 3: rollout (split point -> t = 1) for all branches ---
        remaining_steps = self.steps - split_step_idx
        for i in range(remaining_steps):
            step_idx = split_step_idx + i
            current_t_val = step_idx / self.steps
            # t is broadcast per-branch: [B*K]
            t_tensor = torch.ones(z_branches.shape[0], device=self.device) * current_t_val
            v = self.flow(z_branches, t_tensor, condition=z_cond_branches)
            z_branches = z_branches + v * dt

        # --- Stage 4: decoding & verification ---
        logits = self.ae.decode(z_branches)      # [B*K, L, vocab]
        pred_ids = torch.argmax(logits, dim=-1)  # [B*K, L]
        candidate_codes = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # Accept the first candidate that passes (Pass@K semantics).
        for code in candidate_codes:
            is_pass, msg = self.sandbox.run(code, test_code, entry_point)
            if is_pass:
                return code, True
        # All candidates failed: fall back to the first one
        # (a heuristic pick of the "closest" candidate could go here).
        return candidate_codes[0], False