# Source: Diff-Refine / src/utils/sandbox.py
# Uploaded by 2ira ("Add files using upload-large-folder tool"), commit 77d636f.
import contextlib
import io
import multiprocessing
import os
import queue
import shutil
import signal
import subprocess
import sys
import tempfile
import time
# Timeout support: an exception type plus a signal handler that raises it.
class TimeoutException(Exception):
    """Signals that a time limit was exceeded."""


def timeout_handler(signum, frame):
    """Abort the current computation by raising TimeoutException.

    Uses the ``(signum, frame)`` signature that ``signal.signal`` expects for a
    handler; both arguments are ignored.

    Raises:
        TimeoutException: always.
    """
    raise TimeoutException()
def _exec_code(code_str, test_code, entry_point, result_queue):
"""
运行生成的代码 + 测试用例
"""
capture = io.StringIO()
success = False
error_msg = ""
try:
# 简单的超时机制 (Linux only, Windows需要其他方式)
# signal.signal(signal.SIGALRM, timeout_handler)
# signal.alarm(2) # 2秒超时
with contextlib.redirect_stdout(capture), contextlib.redirect_stderr(capture):
# 创建独立命名空间
scope = {}
# 1. 执行生成的代码 (定义函数)
exec(code_str, scope)
# 2. 检查入口函数是否存在
if entry_point not in scope:
raise ValueError(f"Entry point {entry_point} not found in generated code.")
# 3. 执行测试用例
# HumanEval 的测试用例通常是 "check(entry_point_func)" 的形式
# 我们需要把 check 函数定义也 exec 进去,或者拼接到一起
full_test_script = code_str + "\n" + test_code + f"\ncheck({entry_point})"
exec(full_test_script, scope)
success = True
except Exception as e:
error_msg = str(e)
# finally:
# signal.alarm(0)
result_queue.put((success, error_msg))
class SafeSandbox:
    """Run generated Python code plus its tests in a child process with a time limit.

    Each ``run`` call starts a fresh ``multiprocessing.Process`` that executes
    ``_exec_code``. A child that has not reported back within ``timeout``
    seconds is terminated. The child process is isolation against hangs and
    crashes only, not a security sandbox.
    """

    # How often (seconds) run() re-checks whether the child has died while waiting.
    _POLL_INTERVAL = 0.05

    def __init__(self, timeout=5.0):
        """
        Args:
            timeout: Wall-clock limit in seconds for one run; ``None`` waits
                indefinitely.
        """
        self.timeout = timeout

    def run(self, code, test_code, entry_point):
        """Execute ``code`` against ``test_code`` in a separate process.

        Returns:
            The child's ``(success, error_msg)`` report.
            ``(False, "Timeout")`` if no report arrived within the time limit.
            ``(False, "Unknown Error")`` if the child exited without reporting
            (e.g. it crashed or called ``os._exit``).
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(
            target=_exec_code, args=(code, test_code, entry_point, result_queue)
        )
        proc.start()
        try:
            result = self._wait_for_result(proc, result_queue)
            timed_out = result is None and proc.is_alive()
        finally:
            # Always reap the child. This includes the case where it already
            # reported but is still running (e.g. it left non-daemon threads
            # behind), and the case where the parent was interrupted.
            if proc.is_alive():
                proc.terminate()
            proc.join()
            result_queue.close()

        if result is not None:
            return result
        if timed_out:
            return False, "Timeout"
        return False, "Unknown Error"

    def _wait_for_result(self, proc, result_queue):
        """Wait until the child reports, exits, or the time limit expires.

        Returns the reported tuple, or ``None`` if nothing was received.
        """
        # Read from the queue *before* joining. A process that has put data on a
        # multiprocessing.Queue does not exit until that data is flushed into the
        # underlying pipe. Join-then-read can therefore deadlock on a large report
        # and be misreported as a timeout.
        deadline = None if self.timeout is None else time.monotonic() + self.timeout
        while True:
            if deadline is None:
                wait = self._POLL_INTERVAL
            else:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None
                wait = min(remaining, self._POLL_INTERVAL)
            try:
                return result_queue.get(timeout=wait)
            except queue.Empty:
                if not proc.is_alive():
                    # The child is gone. A report flushed just before it exited
                    # may still be in transit, so take one last brief look.
                    try:
                        return result_queue.get(timeout=self._POLL_INTERVAL)
                    except queue.Empty:
                        return None
class JavaSandbox:
    """Compile and run a Java method together with a ``main`` test harness.

    The method and harness are pasted into one ``public class Solution``,
    compiled with ``javac`` and executed with ``java`` inside a temporary
    directory that is removed afterwards.
    """

    def __init__(self, timeout=5.0, compile_timeout=10):
        """
        Args:
            timeout: Wall-clock limit in seconds for running the compiled program.
            compile_timeout: Wall-clock limit in seconds for ``javac``.

        Raises:
            RuntimeError: If ``javac`` or ``java`` is not found on ``PATH``.
        """
        self.timeout = timeout
        self.compile_timeout = compile_timeout
        # Check for the Java toolchain up front so a missing JDK fails once, clearly.
        if shutil.which("javac") is None or shutil.which("java") is None:
            raise RuntimeError("Java environment (jdk) not found. Please install java.")

    @staticmethod
    def _decode(raw):
        """Decode captured tool output, never raising on non-UTF-8 bytes."""
        # javac/java print in the platform locale encoding (e.g. a localized
        # Windows install may not use UTF-8). Strict decoding could crash here.
        return raw.decode("utf-8", errors="replace") if raw else ""

    def run(self, code, test_code, entry_point):
        """Compile ``code`` + ``test_code`` as class ``Solution`` and run it.

        Args:
            code: Java method source (e.g. a repaired method), inserted
                verbatim into the class body.
            test_code: Source of a ``main`` method that exercises ``code``. A
                non-zero exit status counts as failure.
            entry_point: Method name. It is unused here; ``test_code`` is
                responsible for calling the method.

        Returns:
            ``(True, stdout)`` if the program exits with status 0. Otherwise
            ``(False, message)``, where ``message`` starts with
            ``"Compilation Error:"`` or ``"Runtime Error:"``, or is
            ``"Compilation Timeout"`` / ``"Runtime Timeout"``.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            # The public class is always `Solution`, so the file must be Solution.java.
            file_path = os.path.join(temp_dir, "Solution.java")
            # Assumes `code` holds method(s) and `test_code` a main method.
            # Adjust this splicing if the dataset ships whole classes instead.
            full_source = f"""
public class Solution {{
{code}
{test_code}
}}
"""
            # 1. Write the file. Write UTF-8 and tell javac so explicitly;
            #    otherwise non-ASCII literals depend on the platform default encoding.
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(full_source)

            # 2. Compile.
            compile_cmd = ["javac", "-encoding", "UTF-8", file_path]
            try:
                subprocess.run(
                    compile_cmd,
                    check=True,
                    capture_output=True,
                    timeout=self.compile_timeout,
                )
            except subprocess.CalledProcessError as e:
                return False, f"Compilation Error: {self._decode(e.stderr)}"
            except subprocess.TimeoutExpired:
                return False, "Compilation Timeout"

            # 3. Run.
            run_cmd = ["java", "-cp", temp_dir, "Solution"]
            try:
                result = subprocess.run(run_cmd, capture_output=True, timeout=self.timeout)
            except subprocess.TimeoutExpired:
                return False, "Runtime Timeout"
            if result.returncode == 0:
                return True, self._decode(result.stdout)
            # A harness that only calls System.exit(1) writes nothing to stderr.
            # Report the exit status rather than an empty message.
            detail = self._decode(result.stderr) or f"exit code {result.returncode}"
            return False, f"Runtime Error: {detail}"
# Manual smoke test: run the Java sandbox once with a correct `add` and once with a buggy one.
if __name__ == "__main__":
    java_sandbox = JavaSandbox()

    # main() harness shared by both cases: exits with status 1 when add(1, 1) != 2.
    main_harness = """
public static void main(String[] args) {
if (add(1, 1) == 2) {
System.out.println("PASS");
} else {
System.exit(1);
}
}
"""
    cases = [
        # Correct implementation: expected to succeed.
        ("Test Pass:", """
public static int add(int a, int b) {
return a + b;
}
"""),
        # Buggy implementation (multiplies instead of adding): expected to fail.
        ("Test Fail:", """
public static int add(int a, int b) {
return a * b; // Bug
}
"""),
    ]
    for label, method_src in cases:
        print(label, java_sandbox.run(method_src, main_harness, "add"))