# CounterFeint/tests/test_environment.py
# Uploader: QuantumTransformer
# Commit message: Upload folder using huggingface_hub
# Commit: 28f702f (verified)
"""Tests for the core AdFraudEnvironment."""
import math

from counterfeint.models import AdReviewAction, AdReviewObservation, AdFraudState
from counterfeint.server.environment import AdFraudEnvironment
class TestReset:
    """Checks on what ``AdFraudEnvironment.reset`` returns and leaves behind."""

    def test_reset_returns_observation(self):
        """Reset yields a non-terminal, zero-reward observation listing 5 ads."""
        first_obs = AdFraudEnvironment().reset(seed=42, task_id="task_1")

        assert isinstance(first_obs, AdReviewObservation)
        assert first_obs.done is False
        assert first_obs.reward == 0.0
        assert len(first_obs.available_ads) == 5

    def test_reset_clears_state(self):
        """Resetting after one verdict zeroes the counters and lists 5 ads again."""
        environment = AdFraudEnvironment()
        environment.reset(seed=42, task_id="task_1")

        # Dirty the episode with a single verdict before resetting again.
        approve_first = AdReviewAction(
            action_type="verdict",
            ad_id="ad_001",
            verdict="approve",
            confidence=0.9,
        )
        environment.step(approve_first)

        fresh_obs = environment.reset(seed=42, task_id="task_1")
        snapshot = environment.state

        assert snapshot.step_count == 0
        assert snapshot.reviewed_count == 0
        assert len(fresh_obs.available_ads) == 5

    def test_reset_different_tasks(self):
        """Each task id is expected to expose its own number of ads."""
        ads_per_task = {"task_1": 5, "task_2": 12, "task_3": 20}
        environment = AdFraudEnvironment()

        # The same environment instance is reused across tasks on purpose.
        for task_name, ad_count in ads_per_task.items():
            observation = environment.reset(seed=42, task_id=task_name)
            assert len(observation.available_ads) == ad_count
class TestStep:
    """Behaviour of ``AdFraudEnvironment.step`` for investigate and verdict actions."""

    @staticmethod
    def _assert_reward_close(obs, expected):
        """Assert that ``obs.reward`` matches *expected* within float tolerance.

        Rewards are floats, so an exact ``==`` would break as soon as the
        value is produced by arithmetic rather than a literal constant.
        """
        assert math.isclose(obs.reward, expected, abs_tol=1e-9), (
            f"expected reward {expected}, got {obs.reward}"
        )

    @staticmethod
    def _first_fraud_ad(env):
        """Return the first ad in the current episode labelled ``"fraud"``.

        Fails with a readable assertion (instead of an ``IndexError``) when
        the episode contains no fraud ad at all.
        """
        fraud_ads = [
            a for a in env._episode.ads if a.ground_truth_label == "fraud"
        ]
        assert len(fraud_ads) > 0, "episode contains no ad labelled 'fraud'"
        return fraud_ads[0]

    def test_investigate_returns_findings(self):
        """An investigate action keeps the episode open and returns feedback."""
        env = AdFraudEnvironment()
        env.reset(seed=42, task_id="task_1")
        obs = env.step(AdReviewAction(
            action_type="investigate",
            ad_id="ad_001",
            investigation_target="advertiser_history",
        ))
        assert obs.done is False
        self._assert_reward_close(obs, -0.02)
        assert "Advertiser" in obs.feedback or "Investigation complete" in obs.feedback

    def test_verdict_correct_rejection(self):
        """Rejecting an ad whose ground truth is fraud earns a positive reward."""
        env = AdFraudEnvironment()
        env.reset(seed=42, task_id="task_1")
        ad = self._first_fraud_ad(env)
        obs = env.step(AdReviewAction(
            action_type="verdict", ad_id=ad.ad_id,
            verdict="reject", confidence=0.9,
        ))
        assert obs.reward > 0

    def test_verdict_false_negative_penalty(self):
        """Approving an ad whose ground truth is fraud earns a negative reward."""
        env = AdFraudEnvironment()
        env.reset(seed=42, task_id="task_1")
        ad = self._first_fraud_ad(env)
        obs = env.step(AdReviewAction(
            action_type="verdict", ad_id=ad.ad_id,
            verdict="approve", confidence=0.9,
        ))
        assert obs.reward < 0

    def test_duplicate_verdict_rejected(self):
        """A second verdict on the same ad is expected to yield a -0.02 reward."""
        env = AdFraudEnvironment()
        env.reset(seed=42, task_id="task_1")
        env.step(AdReviewAction(
            action_type="verdict", ad_id="ad_001",
            verdict="approve", confidence=0.5,
        ))
        obs = env.step(AdReviewAction(
            action_type="verdict", ad_id="ad_001",
            verdict="reject", confidence=0.9,
        ))
        self._assert_reward_close(obs, -0.02)

    def test_invalid_ad_id(self):
        """An unknown ad id is expected to yield -0.05 and an "Invalid" message."""
        env = AdFraudEnvironment()
        env.reset(seed=42, task_id="task_1")
        obs = env.step(AdReviewAction(
            action_type="investigate", ad_id="ad_999",
            investigation_target="landing_page",
        ))
        self._assert_reward_close(obs, -0.05)
        assert "Invalid" in obs.feedback

    def test_episode_ends_when_all_reviewed(self):
        """The observation is marked done once every listed ad has a verdict."""
        env = AdFraudEnvironment()
        obs = env.reset(seed=42, task_id="task_1")
        # Copy the id list first: `obs` is rebound on every step.
        for ad_id in list(obs.available_ads):
            obs = env.step(AdReviewAction(
                action_type="verdict", ad_id=ad_id,
                verdict="reject", confidence=0.5,
            ))
        assert obs.done is True

    def test_step_after_done_returns_done(self):
        """Stepping a finished episode stays done and says it is already complete."""
        env = AdFraudEnvironment()
        obs = env.reset(seed=42, task_id="task_1")
        for ad_id in list(obs.available_ads):
            obs = env.step(AdReviewAction(
                action_type="verdict", ad_id=ad_id,
                verdict="reject", confidence=0.5,
            ))
        obs = env.step(AdReviewAction(
            action_type="investigate", ad_id="ad_001",
            investigation_target="landing_page",
        ))
        assert obs.done is True
        assert "already complete" in obs.feedback.lower()
class TestState:
    """Checks on the ``state`` snapshot exposed by the environment."""

    def test_state_tracks_progress(self):
        """One investigate step moves step_count 0 -> 1 and budget 25 -> 24."""
        environment = AdFraudEnvironment()
        environment.reset(seed=42, task_id="task_1")

        before = environment.state
        assert before.task_id == "task_1"
        assert before.total_ads == 5
        assert before.remaining_budget == 25
        assert before.step_count == 0

        probe = AdReviewAction(
            action_type="investigate",
            ad_id="ad_001",
            investigation_target="landing_page",
        )
        environment.step(probe)

        after = environment.state
        assert after.step_count == 1
        assert after.remaining_budget == 24

    def test_grader_score_set_on_completion(self):
        """After a verdict on every ad, grader_score is set and lies in [0, 1]."""
        environment = AdFraudEnvironment()
        initial_obs = environment.reset(seed=42, task_id="task_1")

        # Snapshot the ids up front so stepping cannot affect the iteration.
        pending_ids = list(initial_obs.available_ads)
        for pending_id in pending_ids:
            reject_it = AdReviewAction(
                action_type="verdict",
                ad_id=pending_id,
                verdict="reject",
                confidence=0.5,
            )
            environment.step(reject_it)

        final_state = environment.state
        assert final_state.grader_score is not None
        assert 0.0 <= final_state.grader_score <= 1.0
class TestAntiExploit:
    """Degenerate single-verdict policies are expected to score below 0.7."""

    @staticmethod
    def _score_constant_policy(task_id, verdict, confidence):
        """Give every ad of *task_id* the same verdict and return the grader score.

        Asserts that the score was actually set, so a missing score fails
        with a clear message instead of a ``TypeError`` in a later ``<``
        comparison against ``None``.
        """
        env = AdFraudEnvironment()
        obs = env.reset(seed=42, task_id=task_id)
        # Copy the id list so stepping cannot affect the iteration.
        for ad_id in list(obs.available_ads):
            env.step(AdReviewAction(
                action_type="verdict", ad_id=ad_id,
                verdict=verdict, confidence=confidence,
            ))
        score = env.state.grader_score
        assert score is not None, (
            "grader_score was not set after every ad received a verdict"
        )
        return score

    def test_always_reject_scores_poorly(self):
        """Always-reject on task_2 (5 legit / 5 fraud / 2 escalate) should be punished."""
        score = self._score_constant_policy("task_2", "reject", 0.9)
        assert score < 0.7, f"Always-reject should score poorly, got {score}"

    def test_always_escalate_scores_poorly(self):
        """Always-escalate on task_1 should be punished as well."""
        score = self._score_constant_policy("task_1", "escalate", 0.5)
        assert score < 0.7, f"Always-escalate should score poorly, got {score}"