"""Async integration tests for PyTorchDebugEnv.

Covers lifecycle guards (state/step before reset, step after commit),
investigation actions that grow the observation windows, and the
shape of the per-step reward/info payload.
"""

import pytest

from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
from src.pytorch_debug_env.environment import PyTorchDebugEnv
from src.pytorch_debug_env.models import (
    FinalDiagnosis,
    Hypothesis,
    InvestigationAction,
    PyTorchDebugAction,
)
from src.pytorch_debug_env.scenario_generator import ScenarioGenerator


def make_env():
    """Return a fresh environment wired to the shared bug-template library."""
    return PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))


def base_hypothesis():
    """Default working hypothesis shared by most tests below."""
    return Hypothesis(
        bug_type="missing_zero_grad",
        affected_file="train.py",
        confidence=0.6,
    )


def final_diagnosis():
    """A committed diagnosis consistent with ``base_hypothesis``."""
    return FinalDiagnosis(
        bug_type="missing_zero_grad",
        affected_file="train.py",
        line_range=[9, 14],
        fix_strategy="Call optimizer.zero_grad() before loss.backward()",
        confidence=0.7,
    )


@pytest.mark.asyncio
async def test_state_before_reset_returns_none():
    """state() reports None until the environment has been reset."""
    assert await make_env().state() is None


@pytest.mark.asyncio
async def test_step_without_reset_raises():
    """Stepping a never-reset environment raises RuntimeError."""
    environment = make_env()
    probe = PyTorchDebugAction(current_hypothesis=base_hypothesis())
    with pytest.raises(RuntimeError):
        await environment.step(probe)


@pytest.mark.asyncio
async def test_reveal_file_adds_to_observation():
    """Revealing a file records it in the observation's revealed set."""
    environment = make_env()
    await environment.reset("easy")
    wanted = "data/dataset.py"
    reveal = PyTorchDebugAction(
        current_hypothesis=base_hypothesis(),
        investigation_action=InvestigationAction(action="reveal_file", target=wanted),
    )
    outcome = await environment.step(reveal)
    assert wanted in outcome["observation"].revealed_files


@pytest.mark.asyncio
async def test_step_after_done_raises():
    """Once a diagnosis is committed, any further step raises RuntimeError."""
    environment = make_env()
    await environment.reset("easy")
    commit = PyTorchDebugAction(
        current_hypothesis=base_hypothesis(),
        commit_diagnosis=True,
        final_diagnosis=final_diagnosis(),
    )
    await environment.step(commit)
    with pytest.raises(RuntimeError):
        await environment.step(commit)


@pytest.mark.asyncio
async def test_reward_range_and_info_keys():
    """Step reward stays strictly inside (0, 1) and info carries every reward part."""
    environment = make_env()
    await environment.reset("easy")
    outcome = await environment.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(
                action="reveal_file",
                target="model/attention.py",
            ),
        )
    )
    assert 0.0 < outcome["reward"] < 1.0
    expected_keys = (
        "hypothesis_quality",
        "hypothesis_delta",
        "investigation_reward",
        "diagnosis_reward",
        "confirmation_bonus",
    )
    for key in expected_keys:
        assert key in outcome["info"]


@pytest.mark.asyncio
async def test_extend_loss_curve_increases_window():
    """extend_loss_curve yields a strictly longer loss window than a no-op step."""
    # Same seed for both environments so the scenarios are identical and
    # only the investigation action differs.
    treated = make_env()
    await treated.reset("easy", seed=123)
    extended = await treated.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="extend_loss_curve"),
        )
    )

    control = make_env()
    await control.reset("easy", seed=123)
    baseline = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )

    assert len(extended["observation"].loss_curve_window) > len(
        baseline["observation"].loss_curve_window
    )


@pytest.mark.asyncio
async def test_extend_gpu_profile_increases_window():
    """extend_gpu_profile yields a strictly longer GPU-profile window than a no-op step."""
    treated = make_env()
    await treated.reset("easy", seed=321)
    extended = await treated.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="extend_gpu_profile"),
        )
    )

    control = make_env()
    await control.reset("easy", seed=321)
    baseline = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )

    assert len(extended["observation"].gpu_profile_window) > len(
        baseline["observation"].gpu_profile_window
    )


@pytest.mark.asyncio
async def test_reveal_log_chunk_extends_tail():
    """reveal_log_chunk does not shrink the visible log tail.

    NOTE(review): the assertion is ``>=`` rather than ``>`` — presumably the
    log can already be fully revealed at reset, in which case the chunk is a
    no-op; confirm against the environment's log-windowing behavior.
    """
    treated = make_env()
    await treated.reset("easy", seed=77)
    extended = await treated.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="reveal_log_chunk"),
        )
    )

    control = make_env()
    await control.reset("easy", seed=77)
    baseline = await control.step(
        PyTorchDebugAction(current_hypothesis=base_hypothesis())
    )

    assert len(extended["observation"].training_log_tail) >= len(
        baseline["observation"].training_log_tail
    )


@pytest.mark.asyncio
async def test_run_diagnostic_exposes_report():
    """run_diagnostic populates a truthy diagnostic_report on the observation."""
    environment = make_env()
    await environment.reset("easy", seed=11)
    outcome = await environment.step(
        PyTorchDebugAction(
            current_hypothesis=base_hypothesis(),
            investigation_action=InvestigationAction(action="run_diagnostic"),
        )
    )
    assert outcome["observation"].diagnostic_report