# openenv/tests/test_environment_edge_cases.py
# Author: Priyansh Saxena
# feat: expand scenarios and investigation actions (commit 1435892)
import pytest
from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
from src.pytorch_debug_env.environment import PyTorchDebugEnv
from src.pytorch_debug_env.models import (
FinalDiagnosis,
Hypothesis,
InvestigationAction,
PyTorchDebugAction,
)
from src.pytorch_debug_env.scenario_generator import ScenarioGenerator
def make_env():
    """Build a fresh PyTorchDebugEnv backed by the shared bug templates."""
    return PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
def base_hypothesis():
    """Return the default working hypothesis used by most tests in this module."""
    fields = {
        "bug_type": "missing_zero_grad",
        "affected_file": "train.py",
        "confidence": 0.6,
    }
    return Hypothesis(**fields)
def final_diagnosis():
    """Return a committed diagnosis matching the base hypothesis."""
    fields = {
        "bug_type": "missing_zero_grad",
        "affected_file": "train.py",
        "line_range": [9, 14],
        "fix_strategy": "Call optimizer.zero_grad() before loss.backward()",
        "confidence": 0.7,
    }
    return FinalDiagnosis(**fields)
@pytest.mark.asyncio
async def test_state_before_reset_returns_none():
    """state() must be None until reset() has been called at least once."""
    environment = make_env()
    current = await environment.state()
    assert current is None
@pytest.mark.asyncio
async def test_step_without_reset_raises():
    """Stepping an environment that was never reset is a RuntimeError."""
    environment = make_env()
    with pytest.raises(RuntimeError):
        await environment.step(
            PyTorchDebugAction(current_hypothesis=base_hypothesis())
        )
@pytest.mark.asyncio
async def test_reveal_file_adds_to_observation():
    """A reveal_file action should surface the requested path in revealed_files."""
    environment = make_env()
    await environment.reset("easy")
    path = "data/dataset.py"
    step_result = await environment.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(
                action="reveal_file", target=path
            ),
        )
    )
    assert path in step_result["observation"].revealed_files
@pytest.mark.asyncio
async def test_step_after_done_raises():
    """Once a diagnosis is committed the episode ends; further steps raise."""
    environment = make_env()
    await environment.reset("easy")
    commit_action = PyTorchDebugAction(
        current_hypothesis=base_hypothesis(),
        commit_diagnosis=True,
        final_diagnosis=final_diagnosis(),
    )
    # First step commits the diagnosis and finishes the episode.
    await environment.step(commit_action)
    # Any subsequent step must be rejected.
    with pytest.raises(RuntimeError):
        await environment.step(commit_action)
@pytest.mark.asyncio
async def test_reward_range_and_info_keys():
    """A mid-episode step rewards strictly inside (0, 1) and reports all info keys."""
    environment = make_env()
    await environment.reset("easy")
    outcome = await environment.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(
                action="reveal_file",
                target="model/attention.py",
            ),
        )
    )
    assert 0.0 < outcome["reward"] < 1.0
    expected_keys = (
        "hypothesis_quality",
        "hypothesis_delta",
        "investigation_reward",
        "diagnosis_reward",
        "confirmation_bonus",
    )
    info = outcome["info"]
    for expected in expected_keys:
        assert expected in info
@pytest.mark.asyncio
async def test_extend_loss_curve_increases_window():
    """extend_loss_curve must lengthen the visible loss window vs. a no-op step."""
    probe = make_env()
    await probe.reset("easy", seed=123)
    extended_result = await probe.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="extend_loss_curve"),
        )
    )
    # Control: identical seed, no investigation action.
    control = make_env()
    await control.reset("easy", seed=123)
    control_result = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )
    extended_window = extended_result["observation"].loss_curve_window
    control_window = control_result["observation"].loss_curve_window
    assert len(extended_window) > len(control_window)
@pytest.mark.asyncio
async def test_extend_gpu_profile_increases_window():
    """extend_gpu_profile must lengthen the GPU profile window vs. a no-op step."""
    probe = make_env()
    await probe.reset("easy", seed=321)
    extended_result = await probe.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="extend_gpu_profile"),
        )
    )
    # Control: identical seed, no investigation action.
    control = make_env()
    await control.reset("easy", seed=321)
    control_result = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )
    extended_window = extended_result["observation"].gpu_profile_window
    control_window = control_result["observation"].gpu_profile_window
    assert len(extended_window) > len(control_window)
@pytest.mark.asyncio
async def test_reveal_log_chunk_extends_tail():
    """reveal_log_chunk should not shrink the log tail vs. a no-op step."""
    probe = make_env()
    await probe.reset("easy", seed=77)
    extended_result = await probe.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="reveal_log_chunk"),
        )
    )
    # Control: identical seed, no investigation action.
    control = make_env()
    await control.reset("easy", seed=77)
    control_result = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )
    extended_tail = extended_result["observation"].training_log_tail
    control_tail = control_result["observation"].training_log_tail
    # NOTE(review): >= (not >) matches the original — presumably the log may
    # already be fully revealed; confirm against environment semantics.
    assert len(extended_tail) >= len(control_tail)
@pytest.mark.asyncio
async def test_run_diagnostic_exposes_report():
    """run_diagnostic should attach a truthy diagnostic report to the observation."""
    environment = make_env()
    await environment.reset("easy", seed=11)
    outcome = await environment.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="run_diagnostic"),
        )
    )
    assert outcome["observation"].diagnostic_report