File size: 4,967 Bytes
77d636f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import contextlib
import io
import multiprocessing
import os
import shutil
import signal
import subprocess
import sys
import tempfile

# Timeout support: exception raised by the signal handler below.
class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when an alarm-based time limit fires."""

def timeout_handler(signum, frame):
    """
    Signal handler that aborts the running computation.

    Intended for use with ``signal.signal(signal.SIGALRM, timeout_handler)``.
    Both arguments are ignored.

    Raises:
        TimeoutException: always.
    """
    raise TimeoutException()

def _exec_code(code_str, test_code, entry_point, result_queue):
    """
    Run generated Python code against its test code and report the outcome.

    Used as the child-process target by ``SafeSandbox.run``. Exceptions raised
    by the executed code do not escape; exactly one ``(success, error_msg)``
    tuple is put on ``result_queue``. Time limits are enforced by the parent.

    Args:
        code_str: Source that should define a function named ``entry_point``.
        test_code: Source that defines ``check(candidate)`` (HumanEval style).
        entry_point: Name of the function passed to ``check``.
        result_queue: Any object with a ``put`` method, e.g. a
            ``multiprocessing.Queue``.

    Note:
        The code runs via plain ``exec`` and is NOT sandboxed. Only call this
        inside a disposable process.
    """
    capture = io.StringIO()  # swallows anything the executed code prints
    success = False
    error_msg = ""

    try:
        with contextlib.redirect_stdout(capture), contextlib.redirect_stderr(capture):
            # Fresh namespace so definitions don't leak into this module.
            scope = {}

            # 1. Execute the generated code (defines the candidate function).
            exec(code_str, scope)

            # 2. Make sure the entry point was actually defined.
            if entry_point not in scope:
                raise ValueError(f"Entry point {entry_point} not found in generated code.")

            # 3. Define check() and call it on the entry point. The solution is
            #    already in `scope`, so code_str is not executed a second time
            #    (that would duplicate its side effects and cost).
            exec(test_code + f"\ncheck({entry_point})", scope)

            success = True
    except BaseException as e:  # noqa: BLE001
        # BaseException rather than Exception: generated code may call
        # sys.exit()/exit(). Without this, the worker would die without
        # reporting a result.
        # Bare `assert` failures carry an empty message, so fall back to the
        # exception type name to keep the failure diagnosable.
        error_msg = str(e) or type(e).__name__

    result_queue.put((success, error_msg))

class SafeSandbox:
    """
    Run generated Python code and its tests in a child process with a time limit.

    The child executes the code with plain ``exec`` (see ``_exec_code``). This
    only isolates crashes, hangs and stdout/stderr noise from the caller. It
    does NOT restrict what the code can do (filesystem, network, ...).
    """

    # How often (seconds) run() re-checks for a result or a dead child.
    _POLL_INTERVAL = 0.05

    def __init__(self, timeout=5.0):
        # Wall-clock seconds to wait for the child before reporting "Timeout".
        self.timeout = timeout

    def run(self, code, test_code, entry_point):
        """
        Execute ``code`` followed by ``test_code`` in a fresh process.

        Args:
            code: Python source defining the candidate solution.
            test_code: Python source defining ``check(candidate)``.
            entry_point: Name of the function in ``code`` passed to ``check``.

        Returns:
            One of:
            - ``(success, error_msg)`` as reported by the child;
            - ``(False, "Timeout")`` if no result arrives within
              ``self.timeout`` seconds;
            - ``(False, "Unknown Error")`` if the child exits without
              reporting anything.
        """
        import time
        from queue import Empty

        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(
            target=_exec_code, args=(code, test_code, entry_point, result_queue)
        )
        proc.start()
        deadline = time.monotonic() + self.timeout
        try:
            # Consume the result BEFORE joining. A child that has put data on a
            # multiprocessing.Queue does not exit until that data is flushed to
            # the pipe. Join-then-get can therefore hang on large messages,
            # which would be misreported as a timeout.
            while True:
                try:
                    return result_queue.get(timeout=self._POLL_INTERVAL)
                except Empty:
                    pass
                if not proc.is_alive():
                    # The child may have delivered its result right after the
                    # last poll, so drain once more before declaring failure.
                    try:
                        return result_queue.get(timeout=self._POLL_INTERVAL)
                    except Empty:
                        return False, "Unknown Error"
                if time.monotonic() >= deadline:
                    return False, "Timeout"
        finally:
            # Give a finishing child a moment to exit. Then make sure nothing
            # is left running (timeout path) and the process is reaped.
            proc.join(self._POLL_INTERVAL)
            if proc.is_alive():
                proc.terminate()
                proc.join()

class JavaSandbox:
    """
    Compile and run a Java solution plus its test harness in a temporary directory.

    Requires ``javac`` and ``java`` on PATH; this is checked at construction time.
    """

    def __init__(self, timeout=5.0, compile_timeout=10.0):
        """
        Args:
            timeout: Maximum seconds the compiled program may run.
            compile_timeout: Maximum seconds allowed for ``javac``.

        Raises:
            RuntimeError: If ``javac`` or ``java`` cannot be found on PATH.
        """
        self.timeout = timeout
        self.compile_timeout = compile_timeout

        # Fail fast if no JDK is available.
        if shutil.which("javac") is None or shutil.which("java") is None:
            raise RuntimeError("Java environment (jdk) not found. Please install java.")

    @staticmethod
    def _decode(raw):
        """Decode subprocess output without failing on non-UTF-8 bytes (None -> "")."""
        return raw.decode("utf-8", errors="replace") if raw else ""

    def run(self, code, test_code, entry_point):
        """
        Compile ``code`` + ``test_code`` as class ``Solution`` and execute it.

        Args:
            code: Java member source (e.g. the repaired method) placed inside
                ``public class Solution``.
            test_code: Java member source containing
                ``public static void main(String[] args)`` that exercises
                ``code``. A non-zero exit status counts as failure.
            entry_point: Unused for Java; kept for interface parity with
                ``SafeSandbox.run``.

        Returns:
            ``(True, stdout)`` if the program exits with status 0. Otherwise
            ``(False, message)`` describing a compilation error or timeout, a
            non-zero exit (with stderr), or a runtime timeout.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # The wrapper class is public and named Solution, so the file name
            # must be Solution.java.
            file_path = os.path.join(temp_dir, "Solution.java")

            # Wrap the member snippets in one class. Adjust this if a dataset
            # ships complete classes rather than members.
            full_source = f"""
public class Solution {{
    {code}

    {test_code}
}}
"""
            # Write explicitly as UTF-8, and tell javac so, so that non-ASCII
            # literals/comments survive whatever the platform default is.
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(full_source)

            compile_cmd = ["javac", "-encoding", "UTF-8", "-d", temp_dir, file_path]
            try:
                subprocess.run(
                    compile_cmd,
                    check=True,
                    capture_output=True,
                    timeout=self.compile_timeout,
                )
            except subprocess.CalledProcessError as e:
                return False, f"Compilation Error: {self._decode(e.stderr)}"
            except subprocess.TimeoutExpired:
                return False, "Compilation Timeout"

            run_cmd = ["java", "-cp", temp_dir, "Solution"]
            try:
                result = subprocess.run(run_cmd, capture_output=True, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, "Runtime Timeout"

            if result.returncode == 0:
                return True, self._decode(result.stdout)
            return False, f"Runtime Error: {self._decode(result.stderr)}"

# Manual smoke test for JavaSandbox (needs a JDK on PATH).
if __name__ == "__main__":
    java_sandbox = JavaSandbox()

    # Correct implementation: add(1, 1) == 2, so the harness prints PASS.
    code_pass = """
    public static int add(int a, int b) {
        return a + b;
    }
    """
    test_pass = """
    public static void main(String[] args) {
        if (add(1, 1) == 2) {
            System.out.println("PASS");
        } else {
            System.exit(1);
        }
    }
    """

    # Buggy implementation: multiplies, so add(1, 1) == 1 and the harness exits 1.
    code_fail = """
    public static int add(int a, int b) {
        return a * b; // Bug
    }
    """

    for label, candidate in (("Test Pass:", code_pass), ("Test Fail:", code_fail)):
        print(label, java_sandbox.run(candidate, test_pass, "add"))