# File size: 4,967 Bytes | 77d636f
import sys
import contextlib
import io
import multiprocessing
import os
import queue
import shutil
import signal
import subprocess
import tempfile
import time
# Timeout control
class TimeoutException(Exception):
    """Exception type raised by ``timeout_handler``."""
def timeout_handler(signum, frame):
    """Raise ``TimeoutException`` unconditionally.

    The ``(signum, frame)`` signature matches what ``signal.signal`` expects
    of a handler; neither argument is inspected.
    """
    raise TimeoutException()
def _exec_code(code_str, test_code, entry_point, result_queue):
    """Run generated code against its test harness and report the outcome.

    ``code_str`` is executed in a fresh namespace; ``test_code`` followed by
    a ``check(<entry_point>)`` call is then executed in that same namespace
    (HumanEval-style tests define ``check``). Everything written to
    stdout/stderr while the code runs is captured and discarded.

    Args:
        code_str: Source of the generated solution; must define ``entry_point``.
        test_code: Source of the test harness; expected to define ``check``.
        entry_point: Name of the function handed to ``check``.
        result_queue: Any object with a ``put`` method. Exactly one
            ``(success, error_msg)`` tuple is put on it, whatever happens.

    SECURITY: this ``exec``s arbitrary code in the calling process. Only call
    it from a disposable worker process (as ``SafeSandbox.run`` does); it
    provides no isolation or time limit by itself.
    """
    capture = io.StringIO()
    success = False
    error_msg = ""
    try:
        with contextlib.redirect_stdout(capture), contextlib.redirect_stderr(capture):
            # Fresh namespace so runs cannot see each other's definitions.
            scope = {}

            # 1. Define the candidate function(s).
            exec(code_str, scope)

            # 2. Fail early with a clear message if the entry point is missing.
            if entry_point not in scope:
                raise ValueError(f"Entry point {entry_point} not found in generated code.")

            # 3. Run the tests in the same namespace. The generated code is
            #    already defined there, so it is not executed a second time
            #    (that would repeat its side effects and double its run time).
            exec(f"{test_code}\ncheck({entry_point})", scope)
        success = True
    except Exception as exc:
        # A bare ``assert`` has an empty message; fall back to the class name
        # so the caller can still tell what went wrong.
        error_msg = str(exc) or type(exc).__name__
    except SystemExit as exc:
        # ``sys.exit()`` inside the generated code must not end the worker
        # before a result has been reported.
        error_msg = f"SystemExit: {exc}"
    finally:
        # Always report, even for exceptions not handled above.
        result_queue.put((success, error_msg))
class SafeSandbox:
    """Run generated Python code in a child process under a wall-clock limit."""

    # Seconds between liveness checks of the worker while waiting for a result.
    _POLL_INTERVAL = 0.05
    # Seconds a worker is given to exit on its own (or after SIGTERM).
    _EXIT_GRACE = 1.0

    def __init__(self, timeout=5.0):
        """Store the time limit.

        Args:
            timeout: Seconds the worker may run; ``None`` means no limit.
        """
        self.timeout = timeout

    def run(self, code, test_code, entry_point):
        """Execute ``code`` plus ``test_code`` in a worker process.

        Args:
            code: Source of the generated solution.
            test_code: Source of the test harness.
            entry_point: Name of the function under test.

        Returns:
            The ``(success, error_msg)`` tuple produced by ``_exec_code``,
            ``(False, "Timeout")`` if the limit expired first, or
            ``(False, "Unknown Error")`` if the worker exited without
            reporting anything.
        """
        results = multiprocessing.Queue()
        worker = multiprocessing.Process(
            target=_exec_code,
            args=(code, test_code, entry_point, results),
        )
        worker.start()

        deadline = None if self.timeout is None else time.monotonic() + self.timeout
        outcome = None
        timed_out = False
        try:
            # Read the result BEFORE joining: a child that has queued data
            # does not exit until that data is consumed, so joining first can
            # turn a large result into a spurious timeout.
            while outcome is None:
                wait = self._POLL_INTERVAL
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        timed_out = True
                        break
                    wait = min(wait, remaining)
                try:
                    outcome = results.get(timeout=wait)
                except queue.Empty:
                    if not worker.is_alive():
                        # Worker is gone; allow for a result that was still
                        # in flight when it exited, then stop waiting.
                        try:
                            outcome = results.get(timeout=self._POLL_INTERVAL)
                        except queue.Empty:
                            pass
                        break
        finally:
            self._reap(worker, delivered=outcome is not None)
            results.close()

        if outcome is not None:
            return outcome
        return False, ("Timeout" if timed_out else "Unknown Error")

    def _reap(self, worker, delivered):
        """Make sure ``worker`` has exited, escalating from SIGTERM to SIGKILL."""
        if delivered:
            # The result is in; the worker only needs a moment to shut down.
            worker.join(self._EXIT_GRACE)
        if worker.is_alive():
            worker.terminate()
            worker.join(self._EXIT_GRACE)
        if worker.is_alive():
            # Generated code may have ignored SIGTERM; do not hang on it.
            worker.kill()
        worker.join()
class JavaSandbox:
    """Compile and run a generated Java method wrapped in a ``Solution`` class."""

    def __init__(self, timeout=5.0):
        """Store the run-time limit and verify the Java toolchain is available.

        Args:
            timeout: Seconds the compiled program may run.

        Raises:
            RuntimeError: If ``javac`` or ``java`` cannot be found on PATH.
        """
        self.timeout = timeout
        if shutil.which("javac") is None or shutil.which("java") is None:
            raise RuntimeError("Java environment (jdk) not found. Please install java.")

    @staticmethod
    def _decode(data):
        """Decode tool output without ever raising on undecodable bytes."""
        return data.decode(errors="replace") if data else ""

    def run(self, code, test_code, entry_point):
        """Compile ``code`` together with ``test_code`` and run the result.

        Both snippets are pasted into the body of ``public class Solution``,
        written to ``Solution.java`` in a temporary directory, compiled with
        ``javac`` (10 second limit) and run with ``java`` (``self.timeout``
        limit). How the snippets must be shaped depends on the dataset; this
        method assumes ``code`` is one or more methods and ``test_code``
        supplies ``main``.

        Args:
            code: Java method source (the candidate fix).
            test_code: Java source providing ``main``, which calls the method.
            entry_point: Method name. Not used here; accepted so the signature
                matches ``SafeSandbox.run``.

        Returns:
            ``(True, stdout)`` when the program exits with status 0, otherwise
            ``(False, message)`` where ``message`` starts with
            ``"Compilation Error: "``, ``"Compilation Timeout"``,
            ``"Runtime Error: "`` or ``"Runtime Timeout"``.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # The public class is named Solution, so the file must be too.
            file_path = os.path.join(temp_dir, "Solution.java")

            full_source = (
                "\n"
                "public class Solution {\n"
                f"{code}\n"
                f"{test_code}\n"
                "}\n"
            )

            # 1. Write the source. The encoding is pinned and passed to javac
            #    below so non-ASCII literals do not depend on the platform.
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(full_source)

            # 2. Compile.
            compile_cmd = ["javac", "-encoding", "UTF-8", file_path]
            try:
                subprocess.run(compile_cmd, check=True, capture_output=True, timeout=10)
            except subprocess.CalledProcessError as e:
                return False, f"Compilation Error: {self._decode(e.stderr)}"
            except subprocess.TimeoutExpired:
                return False, "Compilation Timeout"
            except OSError as e:
                # e.g. javac vanished from PATH after construction.
                return False, f"Compilation Error: {e}"

            # 3. Run.
            run_cmd = ["java", "-cp", temp_dir, "Solution"]
            try:
                result = subprocess.run(run_cmd, capture_output=True, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, "Runtime Timeout"
            except OSError as e:
                return False, f"Runtime Error: {e}"

            if result.returncode == 0:
                return True, self._decode(result.stdout)
            # A program that calls System.exit(1) prints nothing to stderr;
            # report the exit status so the message is never empty.
            detail = self._decode(result.stderr) or f"exit code {result.returncode}"
            return False, f"Runtime Error: {detail}"
# Smoke test: push one correct and one buggy Java snippet through JavaSandbox.
if __name__ == "__main__":
    sandbox = JavaSandbox()

    # Correct implementation.
    code_pass = """
public static int add(int a, int b) {
return a + b;
}
"""
    # Harness shared by both runs: prints PASS or exits with status 1.
    test_pass = """
public static void main(String[] args) {
if (add(1, 1) == 2) {
System.out.println("PASS");
} else {
System.exit(1);
}
}
"""
    print("Test Pass:", sandbox.run(code_pass, test_pass, "add"))

    # Buggy implementation (multiplies instead of adding).
    code_fail = """
public static int add(int a, int b) {
return a * b; // Bug
}
"""
    print("Test Fail:", sandbox.run(code_fail, test_pass, "add"))