File size: 4,838 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import torch.nn.functional as F
from tqdm import tqdm

class DiffuMCTS:
    """
    Diffusion-based Monte Carlo Tree Search (simulation-based).

    Pipeline: encode buggy code into a latent -> flow-matching generation
    with a single branching (expansion) point -> parallel rollout of all
    branches -> decode candidates -> verify each candidate in a sandbox
    (Pass@k semantics: succeed as soon as any branch passes).
    """

    def __init__(self, ae, flow, tokenizer, sandbox, device, config):
        """
        Args:
            ae: Autoencoder exposing ``encode(input_ids, attention_mask)``
                and ``decode(latents)``.
            flow: Flow-matching velocity model, called as
                ``flow(z, t, condition=z_cond)``.
            tokenizer: HuggingFace-style tokenizer (callable, and providing
                ``batch_decode``).
            sandbox: Executor with ``run(code, test_code, entry_point)``
                returning ``(passed: bool, message)``.
            device: Torch device on which search tensors are created.
            config: Optional configuration object; the attributes below
                override the search defaults when present. ``None`` (or an
                object lacking the attribute) keeps the defaults, so the
                previous hard-coded behavior is preserved.
        """
        self.ae = ae
        self.flow = flow
        self.tokenizer = tokenizer
        self.sandbox = sandbox
        self.device = device
        self.config = config

        # Search configuration (each overridable via a `config` attribute of
        # the same name; defaults match the original hard-coded values).
        self.num_branches = getattr(config, "num_branches", 8)  # branch count (K)
        # Branching time, expressed in "source time" (1 = source/buggy,
        # 0 = target/fixed); branching happens at forward time 1 - split_t.
        self.split_t = getattr(config, "split_t", 0.5)
        self.noise_scale = getattr(config, "noise_scale", 0.1)  # perturbation strength at the split
        self.steps = getattr(config, "steps", 10)               # number of flow ODE Euler steps
        self.max_length = getattr(config, "max_length", 2048)   # tokenizer truncation length

    @torch.no_grad()
    def solve(self, buggy_code, test_code, entry_point):
        """
        Public entry point: attempt to repair ``buggy_code``.

        Args:
            buggy_code: Source text of the broken program.
            test_code: Test suite text handed to the sandbox.
            entry_point: Name of the function under test.

        Returns:
            tuple[str, bool]: ``(fixed_code, success)`` — the best candidate
            and whether it passed the sandbox tests.
        """
        # 1. Encode the buggy code (source distribution) into latent z.
        tokens = self.tokenizer(
            buggy_code,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        z_buggy = self.ae.encode(tokens['input_ids'], tokens['attention_mask'])

        # 2. Search strategy: a single flow trajectory with parallel branching.
        return self._parallel_rollout(z_buggy, test_code, entry_point)

    def _parallel_rollout(self, z_start, test_code, entry_point):
        """
        Run the branched parallel-rollout search.

        Time convention (matches training): ``z_t = (1 - t) * z_bad + t * z_good``,
        i.e. t=0 is the buggy latent and t=1 is the repaired latent; forward
        Euler integrates from t=0 toward t=1.

        Args:
            z_start: Buggy-code latent, assumed shape [B, L, D] — TODO confirm
                against ``ae.encode``.
            test_code: Test suite text for the sandbox.
            entry_point: Name of the function under test.

        Returns:
            tuple[str, bool]: best candidate code and pass/fail flag.
        """
        B = z_start.shape[0]
        K = self.num_branches

        # --- Stage 1: deterministic flow (t: 0 -> split point) ---
        # Walk a few steps first so the semantics stabilize before branching.
        z_curr = z_start.clone()
        z_cond = z_start.clone()  # the condition is always the buggy latent

        dt = 1.0 / self.steps
        # Euler-step index at which branching happens (split_t is in source
        # time, hence the 1 - split_t conversion to forward time).
        split_step_idx = int((1.0 - self.split_t) * self.steps)

        for i in range(split_step_idx):
            t_tensor = torch.full((B,), i / self.steps, device=self.device)
            v = self.flow(z_curr, t_tensor, condition=z_cond)
            z_curr = z_curr + v * dt  # forward Euler: z_{t+dt} = z_t + v * dt

        # --- Stage 2: expansion (branching) ---
        # Replicate the current latent K times: [B, L, D] -> [B*K, L, D].
        z_branches = z_curr.repeat(K, 1, 1)
        z_cond_branches = z_cond.repeat(K, 1, 1)

        # Inject Gaussian noise for exploration: z' = z + noise.
        z_branches = z_branches + torch.randn_like(z_branches) * self.noise_scale
        # Re-project onto the unit sphere to respect the latent manifold.
        z_branches = F.normalize(z_branches, p=2, dim=-1)

        # --- Stage 3: rollout (split point -> t = 1), all branches in parallel ---
        for step_idx in range(split_step_idx, self.steps):
            t_tensor = torch.full(
                (z_branches.shape[0],), step_idx / self.steps, device=self.device
            )
            v = self.flow(z_branches, t_tensor, condition=z_cond_branches)
            z_branches = z_branches + v * dt

        # --- Stage 4: decoding & verification ---
        logits = self.ae.decode(z_branches)       # [B*K, L, vocab]
        pred_ids = torch.argmax(logits, dim=-1)   # [B*K, L]
        candidate_codes = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # Pass@k verification: succeed as soon as any candidate passes.
        for code in candidate_codes:
            passed, _ = self.sandbox.run(code, test_code, entry_point)
            if passed:
                return code, True

        # All candidates failed: return the first one (a heuristic pick of
        # the "closest" candidate could be substituted here).
        return candidate_codes[0], False