Spaces:
Sleeping
Sleeping
| """ | |
| RefereeEnvironment — the central multi-agent orchestrator for CounterFeint. | |
| Owns a turn-based state machine with three roles: | |
| - Fraudster proposes / modifies ads (actions: propose_ad, modify_pending_ad, end_turn, commit_final) | |
| - Investigator reviews ads (actions: investigate, verdict, link_accounts) | |
| - Auditor audits the trace post-hoc (actions: flag_investigator, flag_fraudster, submit_audit_report) | |
| All three WebSocket endpoints (`/ws/fraudster`, `/ws/investigator`, `/ws/auditor`) | |
| share a single `RefereeEnvironment` instance per match, so state mutations | |
| from one role are immediately visible to the others. | |
| State machine: | |
| fraudster_turn ─end_turn──────► investigator_turn ─turn_cap/all_decided──► fraudster_turn (next round) | |
| │ │ | |
| ├─commit_final───► audit_phase ◄┘ | |
| │ │ | |
| └─action_cap──► investigator_turn │ max_rounds / budget / commit_final | |
| └──────── audit_phase → done ◄───────────────── | |
| Phase 1 keeps the Auditor a no-op scaffold (flags accepted, report accepted, but | |
| graders don't consume them yet). Phase 2A/B/C plug in real audit logic. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import random | |
| import time | |
| from typing import Any, Dict, List, Literal, Optional, Tuple | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import Action, Observation | |
| try: | |
| from ..data.ad_generator import ( | |
| TASK_CONFIGS, | |
| Ad, | |
| GeneratedEpisode, | |
| generate_episode, | |
| ) | |
| from ..data.episode_loader import extend_episode_with_proposal | |
| from ..data.tool_registry import INVESTIGATION_TARGETS, InvestigationToolRegistry | |
| from ..graders.auditor_track_a import ( | |
| investigator_audit_score as track_a_score, | |
| run_track_a, | |
| ) | |
| from ..graders.base_grader import ( | |
| EpisodeRecord, | |
| LinkResult, | |
| VerdictResult, | |
| grade_episode, | |
| ) | |
| from ..graders.multi_agent_rewards import ( | |
| RewardInputs, | |
| compute_episode_rewards, | |
| ) | |
| from ..graders.plausibility_score import compute_queue_plausibility | |
| from ..models import ( | |
| AdFraudState, | |
| AdReviewAction, | |
| AdReviewObservation, | |
| AuditFlag, | |
| AuditorAction, | |
| AuditorObservation, | |
| AuditReport, | |
| FraudsterAction, | |
| FraudsterObservation, | |
| RefereeState, | |
| ) | |
| from .environment import InvestigatorEnvironment | |
| from .evidence_ledger import build_evidence_ledger | |
| except ImportError: | |
| from data.ad_generator import ( | |
| TASK_CONFIGS, | |
| Ad, | |
| GeneratedEpisode, | |
| generate_episode, | |
| ) | |
| from data.episode_loader import extend_episode_with_proposal | |
| from data.tool_registry import INVESTIGATION_TARGETS, InvestigationToolRegistry | |
| from graders.auditor_track_a import ( | |
| investigator_audit_score as track_a_score, | |
| run_track_a, | |
| ) | |
| from graders.base_grader import ( | |
| EpisodeRecord, | |
| LinkResult, | |
| VerdictResult, | |
| grade_episode, | |
| ) | |
| from graders.multi_agent_rewards import ( | |
| RewardInputs, | |
| compute_episode_rewards, | |
| ) | |
| from graders.plausibility_score import compute_queue_plausibility | |
| from models import ( | |
| AdFraudState, | |
| AdReviewAction, | |
| AdReviewObservation, | |
| AuditFlag, | |
| AuditorAction, | |
| AuditorObservation, | |
| AuditReport, | |
| FraudsterAction, | |
| FraudsterObservation, | |
| RefereeState, | |
| ) | |
| from server.environment import InvestigatorEnvironment | |
| from server.evidence_ledger import build_evidence_ledger | |
logger = logging.getLogger(__name__)

# Closed sets of phase and role names used by the turn-based state machine.
Phase = Literal["fraudster_turn", "investigator_turn", "audit_phase", "done"]
Role = Literal["fraudster", "investigator", "auditor"]

# Module-level grader result for parity with the Investigator env (/grader endpoint).
_last_grader_result: Dict[str, Any] = {}


def get_last_grader_result() -> Dict[str, Any]:
    """Return an independent snapshot of the most recent grader result.

    The stored result can hold nested containers (for example the per-ad
    plausibility map and the dumped audit report).  A deep copy is returned
    so a caller may mutate the snapshot freely without corrupting the
    module-level record or any object it shares with an environment.

    Returns:
        A deep copy of the module-level result; ``{}`` when nothing has been
        recorded yet.
    """
    import copy  # local import: only this accessor needs it

    return copy.deepcopy(_last_grader_result)
# Categories the Fraudster may declare by default.  The list is assembled from
# two groups: plausible legitimate categories (useful as camouflage for a
# sophisticated Fraudster) followed by fraud / borderline templates (for
# obvious-fraud or gray-area proposals).
_LEGIT_CAMOUFLAGE_CATEGORIES: Tuple[str, ...] = (
    "ecommerce",
    "saas",
    "local_service",
    "education",
    "fitness",
)
_FRAUD_TEMPLATE_CATEGORIES: Tuple[str, ...] = (
    "fake_giveaway",
    "counterfeit_goods",
    "miracle_cure",
    "advance_fee",
    "fake_crypto",
    "celebrity_endorsement_fraud",
    "clone_brand",
    "gray_area_supplements",
    "network_crypto",
    "network_ecommerce",
    "network_fintech",
    "network_health",
)
DEFAULT_ALLOWED_CATEGORIES: Tuple[str, ...] = (
    _LEGIT_CAMOUFLAGE_CATEGORIES + _FRAUD_TEMPLATE_CATEGORIES
)
| class RefereeEnvironment(Environment[Action, Observation, RefereeState]): | |
| """ | |
| Multi-agent referee. Implements the OpenEnv `Environment` contract with | |
| a generic `Action`/`Observation` typing — each WebSocket route passes | |
| role-specific subclasses into `step()` via the `role` kwarg. | |
| Role-aware entry points (preferred): | |
| - `reset_match(seed, task_id, episode_id, **knobs)` | |
| - `step_as_fraudster(action)` | |
| - `step_as_investigator(action)` | |
| - `step_as_auditor(action)` | |
| - `build_<role>_observation()` | |
| """ | |
| SUPPORTS_CONCURRENT_SESSIONS = True | |
| # Default knobs (overridable via reset kwargs). | |
| DEFAULT_MAX_ROUNDS = 4 | |
| DEFAULT_MAX_PROPOSALS = 5 | |
| # Per-turn action caps. Bumped from (3, 6) to (4, 10) so the | |
| # Investigator can comfortably investigate 2-3 ads per turn AND issue | |
| # verdicts in the same turn without being force-cut to the auditor | |
| # mid-thought (the previous (6) cap was triggering the | |
| # ``max_rounds`` short-circuit on the final round before the | |
| # Investigator could close out pending verdicts). | |
| DEFAULT_MAX_FRAUDSTER_ACTIONS_PER_TURN = 4 | |
| DEFAULT_MAX_INVESTIGATOR_ACTIONS_PER_TURN = 10 | |
| # ------------------------------------------------------------------ | |
| # Lifecycle | |
| # ------------------------------------------------------------------ | |
def __init__(self) -> None:
    """Create an idle referee with default knobs and empty match state.

    Every attribute a match uses is given a neutral starting value here so
    the instance is fully formed before the first reset.
    """
    super().__init__()

    # Match identity, randomness source and the wrapped Investigator env.
    self._match_id: str = str(uuid4())
    self._task_id: str = "task_1"
    self._rng = random.Random()
    self._investigator = InvestigatorEnvironment()
    self._episode: Optional[GeneratedEpisode] = None
    self._registry: Optional[InvestigationToolRegistry] = None

    # Turn-based state machine position.
    self._phase: Phase = "fraudster_turn"
    self._round_number: int = 0

    # Tunable limits, seeded from the class-level defaults.
    self._max_rounds: int = self.DEFAULT_MAX_ROUNDS
    self._max_proposals: int = self.DEFAULT_MAX_PROPOSALS
    self._max_fraudster_actions_per_turn: int = (
        self.DEFAULT_MAX_FRAUDSTER_ACTIONS_PER_TURN
    )
    self._max_investigator_actions_per_turn: int = (
        self.DEFAULT_MAX_INVESTIGATOR_ACTIONS_PER_TURN
    )
    self._allowed_categories: List[str] = list(DEFAULT_ALLOWED_CATEGORIES)

    # Budget / per-turn counters.
    self._proposals_used: int = 0
    self._actions_this_turn: int = 0

    # Per-role logs (consumed by the Auditor) and audit artefacts.
    self._fraudster_log: List[Dict[str, Any]] = []
    self._investigator_log: List[Dict[str, Any]] = []
    self._audit_flags: List[AuditFlag] = []
    self._audit_report: Optional[AuditReport] = None

    # Termination bookkeeping.
    self._fraudster_committed: bool = False
    self._done: bool = False
    self._end_reason: Optional[str] = None

    # Reward accumulators and grading outputs.
    self._fraudster_reward_total: float = 0.0
    self._investigator_reward_total: float = 0.0
    self._auditor_reward_total: float = 0.0
    self._grader_score: Optional[float] = None
    self._per_ad_plausibility: Dict[str, float] = {}
    self._audit_ground_truth: Dict[str, int] = {}

    # Latest feedback line per role, initially blank.
    self._last_feedback: Dict[Role, str] = {
        role_name: "" for role_name in ("fraudster", "investigator", "auditor")
    }

    # Maps a proposal's slot_index to the ad_id allocated for it, so the
    # Fraudster can modify its own earlier proposals without having to know
    # the Referee's ad_id scheme.
    self._proposal_slot_to_ad_id: Dict[int, str] = {}

    # Filled by ``_fraudster_propose_ad`` on success so the audit-log entry
    # for a propose_ad can carry the ad_id and slot the env actually
    # allocated (the raw propose_ad payload has no ad_id field).
    self._last_proposed_ad_id: Optional[str] = None
    self._last_proposed_slot: Optional[int] = None
| # ------------------------------------------------------------------ | |
| # OpenEnv surface (generic) | |
| # ------------------------------------------------------------------ | |
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> Observation:
    """Generic reset entry point.

    Forwards ``seed``, ``episode_id`` and any extra keyword arguments to
    ``reset_match`` and then hands back the *Fraudster* observation, since
    the Fraudster is the role that moves first.  Role-specific endpoints may
    instead call ``build_<role>_observation()`` themselves.
    """
    match_options = {"seed": seed, "episode_id": episode_id, **kwargs}
    self.reset_match(**match_options)
    opening_observation = self.build_fraudster_observation()
    return opening_observation
def step(
    self,
    action: Action,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> Observation:
    """Role-aware generic step.

    Reads ``role`` from the keyword arguments and delegates to the matching
    ``step_as_<role>`` method, returning that role's observation.
    ``timeout_s`` is accepted but not used here.

    Raises:
        ValueError: if ``role`` is missing or is not one of the three roles.
    """
    role: Optional[Role] = kwargs.get("role")
    for known_role in ("fraudster", "investigator", "auditor"):
        if role == known_role:
            handler = getattr(self, f"step_as_{known_role}")
            return handler(action)  # type: ignore[arg-type]
    raise ValueError(
        "RefereeEnvironment.step(action, role=...) requires a role of "
        "'fraudster', 'investigator', or 'auditor'."
    )
def state(self) -> RefereeState:
    """Assemble a ``RefereeState`` snapshot of the current match.

    The step count is the number of logged Fraudster actions plus logged
    Investigator actions plus recorded audit flags.  The reported per-turn
    action cap is the Fraudster's cap during ``fraudster_turn`` and the
    Investigator's cap in every other phase.
    """
    inv_state = self._investigator.state

    logged_steps = (
        len(self._fraudster_log)
        + len(self._investigator_log)
        + len(self._audit_flags)
    )

    if self._phase == "fraudster_turn":
        turn_cap = self._max_fraudster_actions_per_turn
    else:
        turn_cap = self._max_investigator_actions_per_turn

    investigator_snapshot = inv_state.model_dump() if inv_state else {}
    report = self._audit_report
    report_snapshot = report.model_dump() if report else None

    return RefereeState(
        episode_id=self._match_id,
        step_count=logged_steps,
        task_id=self._task_id,
        phase=self._phase,
        round_number=self._round_number,
        max_rounds=self._max_rounds,
        proposals_used=self._proposals_used,
        max_proposals=self._max_proposals,
        actions_this_turn=self._actions_this_turn,
        max_actions_per_turn=turn_cap,
        investigator_state=investigator_snapshot,
        # Lists are copied so the snapshot does not alias the live logs.
        fraudster_proposals=list(self._fraudster_log),
        investigator_action_log=list(self._investigator_log),
        fraudster_committed=self._fraudster_committed,
        audit_report=report_snapshot,
        fraudster_reward=self._fraudster_reward_total,
        investigator_reward=self._investigator_reward_total,
        auditor_reward=self._auditor_reward_total,
        grader_score=self._grader_score,
        end_reason=self._end_reason,
    )
| # ------------------------------------------------------------------ | |
| # Match setup | |
| # ------------------------------------------------------------------ | |
def reset_match(
    self,
    *,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    task_id: Optional[str] = None,
    max_rounds: Optional[int] = None,
    max_proposals: Optional[int] = None,
    max_fraudster_actions_per_turn: Optional[int] = None,
    max_investigator_actions_per_turn: Optional[int] = None,
    allowed_categories: Optional[List[str]] = None,
    episode: Optional[GeneratedEpisode] = None,
) -> None:
    """Initialize a fresh match. Sets phase to `fraudster_turn`, round 1.

    Args:
        seed: Seed for the match RNG and episode generation.  When omitted a
            random 32-bit seed is drawn.
        episode_id: Used as the match id; a new UUID string when omitted.
        task_id: Key into ``TASK_CONFIGS``.  A missing or unknown id falls
            back to ``"task_1"`` (an unknown id is logged as a warning).
        max_rounds, max_proposals, max_fraudster_actions_per_turn,
        max_investigator_actions_per_turn: Limits for this match.  Each one
            resolves as: explicit argument > task-config attribute > class
            default.
        allowed_categories: Categories the Fraudster may declare.  Resolves
            as: explicit argument > the task config's
            ``allowed_fraud_categories`` > ``DEFAULT_ALLOWED_CATEGORIES``.
        episode: Pre-built episode to use instead of generating one.
    """
    self._match_id = episode_id or str(uuid4())

    requested_task = task_id or "task_1"
    if requested_task not in TASK_CONFIGS:
        # Keep the lenient fallback, but make it visible: a silently
        # swapped task is very hard to diagnose from match results alone.
        logger.warning(
            "Unknown task_id %r; falling back to 'task_1'.", requested_task
        )
        requested_task = "task_1"
    self._task_id = requested_task

    # ``UUID.int`` gives the random bits directly rather than going through
    # ``hash()``; the low 32 bits are kept as the seed.
    effective_seed = seed if seed is not None else uuid4().int & 0xFFFFFFFF
    self._rng = random.Random(effective_seed)

    # Resolve each knob with precedence: explicit kwarg > TaskConfig curriculum > class default.
    task_cfg = TASK_CONFIGS[self._task_id]

    def _resolve(arg_value: Optional[int], cfg_attr: str, default: int) -> int:
        if arg_value is not None:
            return arg_value
        cfg_val = getattr(task_cfg, cfg_attr, None)
        return cfg_val if cfg_val is not None else default

    self._max_rounds = _resolve(max_rounds, "max_rounds", self.DEFAULT_MAX_ROUNDS)
    self._max_proposals = _resolve(
        max_proposals, "max_proposals", self.DEFAULT_MAX_PROPOSALS
    )
    self._max_fraudster_actions_per_turn = _resolve(
        max_fraudster_actions_per_turn,
        "max_fraudster_actions_per_turn",
        self.DEFAULT_MAX_FRAUDSTER_ACTIONS_PER_TURN,
    )
    self._max_investigator_actions_per_turn = _resolve(
        max_investigator_actions_per_turn,
        "max_investigator_actions_per_turn",
        self.DEFAULT_MAX_INVESTIGATOR_ACTIONS_PER_TURN,
    )

    cfg_categories = getattr(task_cfg, "allowed_fraud_categories", None)
    if allowed_categories is not None:
        self._allowed_categories = list(allowed_categories)
    elif cfg_categories:
        self._allowed_categories = list(cfg_categories)
    else:
        self._allowed_categories = list(DEFAULT_ALLOWED_CATEGORIES)

    if episode is not None:
        self._episode = episode
    else:
        self._episode = generate_episode(effective_seed, self._task_id)
    self._registry = InvestigationToolRegistry.from_episode(self._episode)

    # The wrapped Investigator env shares this match's episode and registry.
    self._investigator.reset(
        seed=effective_seed,
        episode_id=self._match_id,
        task_id=self._task_id,
        episode=self._episode,
        registry=self._registry,
        queue_may_grow=True,
    )

    # Clear all per-match state.
    self._phase = "fraudster_turn"
    self._round_number = 1
    self._proposals_used = 0
    self._actions_this_turn = 0
    self._fraudster_log = []
    self._investigator_log = []
    self._audit_flags = []
    self._audit_report = None
    self._fraudster_committed = False
    self._done = False
    self._end_reason = None
    self._fraudster_reward_total = 0.0
    self._investigator_reward_total = 0.0
    self._auditor_reward_total = 0.0
    self._grader_score = None
    self._per_ad_plausibility = {}
    self._audit_ground_truth = {}
    self._proposal_slot_to_ad_id = {}
    self._last_proposed_ad_id = None
    self._last_proposed_slot = None
    self._last_feedback = {
        "fraudster": (
            f"Match started. Round 1 of {self._max_rounds}. "
            f"You may propose up to {self._max_proposals} ads total, "
            f"{self._max_fraudster_actions_per_turn} actions per turn."
        ),
        "investigator": (
            "Waiting for Fraudster to finish their turn. The ad queue may "
            "grow during this episode as the Fraudster proposes new ads."
        ),
        "auditor": "Match in progress. Waiting for audit phase.",
    }
| # ------------------------------------------------------------------ | |
| # Fraudster step handler | |
| # ------------------------------------------------------------------ | |
def step_as_fraudster(self, action: FraudsterAction) -> FraudsterObservation:
    """Apply one Fraudster action and return the Fraudster's observation.

    ``propose_ad`` and ``modify_pending_ad`` are delegated to their handlers
    and count toward the per-turn action cap.  ``end_turn`` hands control to
    the Investigator and ``commit_final`` jumps to the audit phase.  Any
    other action type is penalised with -0.05.  Every action, valid or not,
    is appended to the Fraudster log.  After the action, control passes to
    the Investigator automatically if the per-turn cap or the proposal
    budget has been reached.
    """
    self._guard_phase("fraudster_turn", role="fraudster")
    assert self._episode is not None and self._registry is not None

    kind = action.action_type
    counts_toward_cap = kind in ("propose_ad", "modify_pending_ad")
    messages: List[str] = []
    reward = 0.0

    if counts_toward_cap:
        handler = (
            self._fraudster_propose_ad
            if kind == "propose_ad"
            else self._fraudster_modify_pending_ad
        )
        reward, text = handler(action)
        messages.append(text)
        self._actions_this_turn += 1
    elif kind == "end_turn":
        messages.append("Fraudster ended turn. Control passes to Investigator.")
        self._transition(to="investigator_turn", note="fraudster end_turn")
        reward = 0.0
    elif kind == "commit_final":
        messages.append(
            "Fraudster committed (no more proposals). Jumping to audit phase."
        )
        self._fraudster_committed = True
        self._end_reason = "commit_final"
        self._transition(to="audit_phase", note="fraudster commit_final")
        reward = 0.0
    else:
        messages.append(f"Unknown Fraudster action_type '{kind}'.")
        reward = -0.05

    self._fraudster_reward_total += reward
    self._last_feedback["fraudster"] = " ".join(messages).strip()
    self._fraudster_log.append(self._serialize_fraudster_action(action, reward))

    # Automatic hand-over #1: per-turn action cap reached.
    cap_reached = (
        self._phase == "fraudster_turn"
        and counts_toward_cap
        and self._actions_this_turn >= self._max_fraudster_actions_per_turn
    )
    if cap_reached:
        self._transition(to="investigator_turn", note="fraudster action cap")

    # Automatic hand-over #2: proposal budget used up.  The phase check
    # makes this a no-op when hand-over #1 already fired.
    budget_spent = (
        self._phase == "fraudster_turn"
        and kind == "propose_ad"
        and self._proposals_used >= self._max_proposals
    )
    if budget_spent:
        self._last_feedback["fraudster"] += (
            " Proposal budget exhausted — control will pass to Investigator."
        )
        self._transition(to="investigator_turn", note="proposal budget exhausted")

    return self.build_fraudster_observation(reward=reward)
def _fraudster_propose_ad(self, action: FraudsterAction) -> Tuple[float, str]:
    """Validate a ``propose_ad`` action and, if valid, add the ad to the queue.

    Returns:
        ``(reward, feedback)``: ``-0.05`` with an explanation when the
        proposal budget is spent, the ad copy is blank, the category is
        missing, or the category is not allowed; ``0.02`` with a
        confirmation message once the ad has been added.
    """
    penalty = -0.05

    used, budget = self._proposals_used, self._max_proposals
    if used >= budget:
        return penalty, f"Proposal budget exhausted ({used}/{budget})."

    copy_text = action.ad_copy
    if not copy_text or not copy_text.strip():
        return penalty, "propose_ad requires non-empty `ad_copy`."

    category = action.category
    if not category:
        return penalty, "propose_ad requires `category`."
    if category not in self._allowed_categories:
        allowed = ", ".join(self._allowed_categories)
        return penalty, (
            f"category '{category}' not in allowed_categories. "
            f"Use one of: {allowed}."
        )

    assert self._episode is not None and self._registry is not None

    # Each proposal gets its own seed drawn from the match RNG.
    proposal_seed = self._rng.randint(0, 2**31 - 1)
    new_ad = extend_episode_with_proposal(
        episode=self._episode,
        registry=self._registry,
        seed=proposal_seed,
        ad_copy=action.ad_copy,
        category=action.category,
        landing_page_blurb=action.landing_page_blurb,
        targeting_summary=action.targeting_summary,
    )

    # Slots are numbered in proposal order, starting at zero.
    slot = self._proposals_used
    self._proposal_slot_to_ad_id[slot] = new_ad.ad_id
    self._proposals_used = slot + 1

    # Remember what was just allocated so the audit-log entry for this
    # propose_ad can carry the resolved ad_id and slot; the action payload
    # itself has neither, because both are allocated by the environment.
    self._last_proposed_ad_id = new_ad.ad_id
    self._last_proposed_slot = slot

    self._investigator.notify_queue_grew(new_ad.ad_id)

    queue_size = len(self._episode.ads)
    message = (
        f"Proposal #{slot + 1} accepted: ad_id={new_ad.ad_id}, category={new_ad.category}. "
        f"Queue is now {queue_size} ads."
    )
    return 0.02, message
def _fraudster_modify_pending_ad(self, action: FraudsterAction) -> Tuple[float, str]:
    """Edit the ad copy and/or landing-page blurb of an earlier proposal.

    The proposal is addressed by ``slot_index``.  An ad can no longer be
    edited once the Investigator has recorded a verdict for it.

    Returns:
        ``(reward, feedback)``: ``-0.05`` for a missing/unknown slot, a
        locked ad, or an ad missing from the episode; ``-0.02`` when no
        field ended up changed; ``0.01`` with the list of changed fields
        otherwise.
    """
    penalty = -0.05

    slot = action.slot_index
    if slot is None:
        return penalty, "modify_pending_ad requires `slot_index`."
    if slot not in self._proposal_slot_to_ad_id:
        return penalty, f"Unknown slot_index {slot}. Propose an ad first."
    ad_id = self._proposal_slot_to_ad_id[slot]

    assert self._episode is not None and self._registry is not None

    # Locked once the Investigator has already rendered a verdict.
    existing_verdict = self._investigator.verdicts.get(ad_id, {}).get("verdict")
    if existing_verdict:
        return (
            penalty,
            f"Cannot modify {ad_id}: Investigator already rendered verdict "
            f"'{existing_verdict}'.",
        )

    target_ad = next(
        (candidate for candidate in self._episode.ads if candidate.ad_id == ad_id),
        None,
    )
    if target_ad is None:
        return penalty, f"Internal error: ad {ad_id} not in episode."

    changed: List[str] = []

    new_copy = action.new_ad_copy
    if new_copy is not None and new_copy.strip():
        # Trimmed and capped at 2000 characters.
        target_ad.ad_copy = new_copy.strip()[:2000]
        changed.append("ad_copy")

    new_blurb = action.new_landing_page_blurb
    if new_blurb is not None and new_blurb.strip():
        current_page = self._episode.landing_pages.get(ad_id)
        if current_page is not None:
            import dataclasses

            # The page is rebuilt via ``dataclasses.replace`` rather than
            # mutated, then the episode's landing page, the investigation
            # text and the tool registry are all updated to the new text.
            updated_page = dataclasses.replace(
                current_page, content_summary=new_blurb.strip()[:2000]
            )
            self._episode.landing_pages[ad_id] = updated_page
            page_text = updated_page.to_investigation_text()
            self._episode.investigation_data.setdefault(ad_id, {})["landing_page"] = page_text
            self._registry.update_ad(ad_id, {"landing_page": page_text})
            changed.append("landing_page")

    if not changed:
        return -0.02, "modify_pending_ad had nothing to change."
    return 0.01, f"Modified {ad_id} fields: {', '.join(changed)}."
| # ------------------------------------------------------------------ | |
| # Investigator step handler | |
| # ------------------------------------------------------------------ | |
def step_as_investigator(self, action: AdReviewAction) -> AdReviewObservation:
    """Forward one Investigator action to the wrapped env and advance the match.

    The wrapped environment's observation is returned after its ``done``
    flag has been overwritten here.  After the action has been logged, the
    method decides, in this order, whether to:

    1. enter the audit phase because the wrapped env reported ``done``;
    2. enter the audit phase because every ad has a verdict and the
       Fraudster has committed, the last round is reached, or the proposal
       budget is spent;
    3. on hitting the per-turn action cap, either enter the audit phase
       (last round or Fraudster committed) or start the next round with a
       Fraudster turn.
    """
    self._guard_phase("investigator_turn", role="investigator")
    assert self._episode is not None

    obs = self._investigator.step(action)
    step_reward = float(obs.reward or 0.0)
    self._investigator_reward_total += step_reward
    self._investigator_log.append(self._serialize_investigator_action(action, obs))
    self._actions_this_turn += 1
    self._last_feedback["investigator"] = obs.feedback or ""

    def _enter_audit(reason: str, note: str) -> None:
        # Keep an end reason that was already recorded; otherwise use ours.
        self._end_reason = self._end_reason or reason
        self._transition(to="audit_phase", note=note)
        # The match isn't over until the Auditor submits.
        obs.done = False

    everything_decided = self._all_ads_decided()
    last_round = self._round_number >= self._max_rounds

    if bool(obs.done):
        _enter_audit("investigator_done", "investigator env signalled done")
        return obs

    no_more_ads_coming = (
        self._fraudster_committed
        or last_round
        or self._proposals_used >= self._max_proposals
    )
    if everything_decided and no_more_ads_coming:
        _enter_audit("all_decided", "all ads decided")
        return obs

    if self._actions_this_turn >= self._max_investigator_actions_per_turn:
        if last_round or self._fraudster_committed:
            _enter_audit("max_rounds", "max rounds reached")
        else:
            self._round_number += 1
            self._transition(to="fraudster_turn", note="investigator action cap")
            # Warn when the upcoming Investigator turn is the last one, so a
            # slow-to-verdict policy knows pending ads will otherwise be
            # auto-approved.
            if self._round_number == self._max_rounds:
                self._last_feedback["investigator"] = (
                    "Final round next: pending ads not given an explicit "
                    "verdict will auto-approve at audit time."
                )

    obs.done = self._phase == "done"
    return obs
def _all_ads_decided(self) -> bool:
    """Return True when every ad in the episode has an entry in the verdicts map.

    Returns False when no episode is loaded.
    """
    episode = self._episode
    if episode is None:
        return False
    decided = self._investigator.verdicts
    for ad in episode.ads:
        if ad.ad_id not in decided:
            return False
    return True
| # ------------------------------------------------------------------ | |
| # Auditor step handler | |
| # ------------------------------------------------------------------ | |
def step_as_auditor(self, action: AuditorAction) -> AuditorObservation:
    """Apply one Auditor action and return the Auditor's observation.

    Supported action types:

    * ``flag_investigator`` / ``flag_fraudster`` record a Track A / Track B
      ``AuditFlag`` (flag type defaults to ``"unspecified"``, severity to
      ``0.5``).
    * ``submit_audit_report`` builds the ``AuditReport`` from the recorded
      flags and finalizes the match.  Both scores default to the values
      computed by ``_compute_default_track_scores``; values supplied in the
      report payload override them and are clamped to ``[0, 1]``.  A supplied
      value that is ``None``, non-numeric or NaN is ignored in favour of the
      default rather than raising, so a malformed payload cannot abort the
      submission and leave the match stuck in the audit phase.

    Any other action type only produces an "unknown action" feedback line.
    """
    import math  # local import: only needed for NaN detection below

    self._guard_phase("audit_phase", role="auditor")

    def _unit_score(raw: Any, fallback: float) -> float:
        # Coerce a caller-supplied score to a float in [0, 1], falling back
        # to the grader-derived default when the value is unusable.
        try:
            value = float(raw)
        except (TypeError, ValueError):
            value = float(fallback)
        if math.isnan(value):
            value = float(fallback)
        return min(1.0, max(0.0, value))

    kind = action.action_type
    flag_tracks = {"flag_investigator": "A", "flag_fraudster": "B"}
    track = flag_tracks.get(kind) if isinstance(kind, str) else None

    if track is not None:
        flag = AuditFlag(
            track=track,
            target_ad_id=action.target_ad_id,
            flag_type=action.flag_type or "unspecified",
            severity=action.severity if action.severity is not None else 0.5,
            note=action.note or "",
        )
        self._audit_flags.append(flag)
        feedback = (
            f"Track {track} flag recorded: {flag.flag_type} (severity={flag.severity:.2f})."
        )
    elif kind == "submit_audit_report":
        report_payload = action.audit_report or {}
        track_a_flags = [f for f in self._audit_flags if f.track == "A"]
        track_b_flags = [f for f in self._audit_flags if f.track == "B"]

        # Track A/B score *defaults* come from the real graders running over
        # the episode record, so even an empty report gets a principled
        # score.  Caller-supplied values override these.
        default_a, default_b = self._compute_default_track_scores()
        investigator_score = _unit_score(
            report_payload.get("investigator_audit_score", default_a), default_a
        )
        fraudster_score = _unit_score(
            report_payload.get("fraudster_plausibility_score", default_b), default_b
        )

        self._audit_report = AuditReport(
            track_a_flags=track_a_flags,
            track_b_flags=track_b_flags,
            investigator_audit_score=investigator_score,
            fraudster_plausibility_score=fraudster_score,
            # Notes come from the payload, else the action note; capped at
            # 4000 characters.
            notes=str(report_payload.get("notes", "") or action.note or "")[:4000],
        )
        feedback = (
            "Audit report submitted. "
            f"Track A flags: {len(track_a_flags)}. "
            f"Track B flags: {len(track_b_flags)}. "
            f"investigator_audit_score={investigator_score:.2f}, "
            f"fraudster_plausibility_score={fraudster_score:.2f}."
        )
        self._finalize_audit()
    else:
        feedback = f"Unknown Auditor action_type '{kind}'."

    self._last_feedback["auditor"] = feedback
    return self.build_auditor_observation(feedback=feedback)
def _finalize_audit(self) -> None:
    """Grade the match, assign final per-role rewards and move to `done`.

    Builds the episode record, grades it, feeds the record, audit report and
    both action logs to the multi-agent reward model, overwrites the three
    reward totals with its output, publishes a summary to the module-level
    grader result, and finally transitions to `done`.  Does nothing when no
    episode is loaded.
    """
    if self._episode is None:
        return

    record = self._build_episode_record()
    self._grader_score = grade_episode(record)

    # Without a submitted report, use a flag-free report with perfect scores.
    report_for_rewards = self._audit_report or AuditReport(
        track_a_flags=[],
        track_b_flags=[],
        investigator_audit_score=1.0,
        fraudster_plausibility_score=1.0,
        notes="",
    )
    data_seen = self._registry.to_dict() if self._registry else {}
    reward_inputs = RewardInputs(
        record=record,
        audit_report=report_for_rewards,
        fraudster_proposal_log=list(self._fraudster_log),
        investigator_action_log=list(self._investigator_log),
        investigation_data_seen=data_seen,
        fraudster_ad_ids=list(self._proposal_slot_to_ad_id.values()),
    )
    rewards = compute_episode_rewards(reward_inputs)

    # Final rewards replace (not add to) the running totals.
    self._fraudster_reward_total = float(rewards["fraudster"])
    self._investigator_reward_total = float(rewards["investigator"])
    self._auditor_reward_total = float(rewards["auditor"])
    self._per_ad_plausibility = dict(rewards.get("per_ad_plausibility") or {})
    self._audit_ground_truth = dict(rewards.get("audit_ground_truth") or {})

    step_total = (
        len(self._fraudster_log)
        + len(self._investigator_log)
        + len(self._audit_flags)
    )
    submitted_report = (
        self._audit_report.model_dump() if self._audit_report else None
    )

    global _last_grader_result
    _last_grader_result = {
        "match_id": self._match_id,
        "task_id": self._task_id,
        "grader_score": self._grader_score,
        "phase": "done",
        "total_steps": step_total,
        "fraudster_reward": self._fraudster_reward_total,
        "investigator_reward": self._investigator_reward_total,
        "auditor_reward": self._auditor_reward_total,
        "per_ad_plausibility": self._per_ad_plausibility,
        "audit_ground_truth": self._audit_ground_truth,
        "proposals_used": self._proposals_used,
        "end_reason": self._end_reason,
        "audit_report": submitted_report,
    }

    self._transition(to="done", note="audit report submitted")
    self._done = True
def _compute_default_track_scores(self) -> Tuple[float, float]:
    """Compute grader-derived defaults for the two audit-report scores.

    Returns:
        ``(investigator_audit_score, fraudster_plausibility_score)``.  The
        first comes from scoring the Track A flags raised over the episode
        record; the second is the queue plausibility of the Fraudster log.
        Both are ``1.0`` when no episode is loaded, and the second is
        ``1.0`` when the plausibility grader reports no per-ad results.
    """
    if self._episode is None:
        return 1.0, 1.0

    record = self._build_episode_record()
    data_seen = self._registry.to_dict() if self._registry else {}

    grader_flags = run_track_a(
        record,
        investigator_actions=list(self._investigator_log),
        investigation_data_seen=data_seen,
    )
    investigator_score = track_a_score(grader_flags)

    per_ad_scores, _unused_flags, queue_score = compute_queue_plausibility(
        self._fraudster_log
    )
    # With no per-ad results there is nothing to judge, so plausibility is
    # treated as perfect rather than penalised.
    if per_ad_scores:
        fraudster_score = queue_score
    else:
        fraudster_score = 1.0
    return investigator_score, fraudster_score
def _build_episode_record(self) -> EpisodeRecord:
    """Assemble an ``EpisodeRecord`` from the Investigator's view of the match.

    Includes one ``VerdictResult`` per ad that has a (truthy) verdict entry,
    one ``LinkResult`` per recorded link, ground-truth metadata for every ad
    in the episode, and the fraud-ring counts and sizes.
    """
    assert self._episode is not None
    episode = self._episode

    verdicts = self._investigator.verdicts
    links = self._investigator.links
    inv_state: AdFraudState = self._investigator.state

    # Ads without a verdict entry are simply left out of the results.
    ad_verdict_pairs = ((ad, verdicts.get(ad.ad_id)) for ad in episode.ads)
    verdict_results = [
        VerdictResult(
            ad_id=ad.ad_id,
            verdict=entry["verdict"],
            confidence=entry.get("confidence", 0.5),
            ground_truth=entry["ground_truth"],
            auto_approved=entry.get("auto_approved", False),
        )
        for ad, entry in ad_verdict_pairs
        if entry
    ]

    link_results = [
        LinkResult(
            ad_id_1=link["ad_id_1"],
            ad_id_2=link["ad_id_2"],
            correct=link["correct"],
        )
        for link in links
    ]

    ads_metadata = []
    for ad in episode.ads:
        ads_metadata.append(
            {
                "ad_id": ad.ad_id,
                "ground_truth": ad.ground_truth_label,
                "severity": ad.severity,
            }
        )

    rings = episode.fraud_rings
    return EpisodeRecord(
        task_id=self._task_id,
        total_steps=inv_state.step_count,
        action_budget=episode.task_config.action_budget,
        verdicts=verdict_results,
        links=link_results,
        ads_metadata=ads_metadata,
        n_fraud_rings=len(rings),
        ring_sizes=[len(ring.member_ad_ids) for ring in rings],
    )
| # ------------------------------------------------------------------ | |
| # Observation builders | |
| # ------------------------------------------------------------------ | |
def build_fraudster_observation(
    self, *, reward: float = 0.0
) -> FraudsterObservation:
    """Build the observation handed to the Fraudster role.

    Args:
        reward: Reward value to attach to the observation.

    Returns:
        A ``FraudsterObservation``. When no episode is loaded it carries
        only ``done``, ``reward``, ``phase`` and a fixed feedback message.
    """
    phase = self._phase
    is_done = phase == "done"
    episode = self._episode
    if episode is None:
        return FraudsterObservation(
            done=is_done,
            reward=reward,
            feedback="No episode loaded. Call reset() first.",
            phase=phase,
        )

    queue = self._build_queue_summary()
    verdict_history = self._build_verdict_history()
    investigations = self._investigator.investigations
    rounds_remaining = max(0, self._max_rounds - self._round_number + 1)

    # The per-turn action allowance is only reported during the
    # Fraudster's own turn; in every other phase it is 0.
    if phase == "fraudster_turn":
        actions_left = max(
            0,
            self._max_fraudster_actions_per_turn - self._actions_this_turn,
        )
    else:
        actions_left = 0

    proposal_signals = self._build_my_proposal_signals()
    proposals_remaining = max(0, self._max_proposals - self._proposals_used)

    return FraudsterObservation(
        done=is_done,
        reward=reward,
        feedback=self._last_feedback["fraudster"],
        phase=phase,
        task_id=getattr(episode.task_config, "task_id", ""),
        round_number=self._round_number,
        rounds_remaining=rounds_remaining,
        proposals_used=self._proposals_used,
        proposals_remaining=proposals_remaining,
        actions_left_this_turn=actions_left,
        current_queue=queue,
        prior_verdicts=verdict_history,
        investigation_targets_used=investigations,
        allowed_categories=list(self._allowed_categories),
        my_proposal_signals=proposal_signals,
    )
def _build_my_proposal_signals(self) -> Dict[str, Dict[str, Any]]:
    """Return evidence-ledger entries for the Fraudster's own proposals.

    The ledger is built by ``build_evidence_ledger`` for the proposed ad
    ids only, with an ``investigations`` mapping that lists all four
    investigation targets for each of them, i.e. as if everything had been
    pulled. The Fraudster authored these ads, so it is allowed to see what
    the env assigned to them. Each entry is then annotated with the ad's
    ``slot_index`` and the Investigator's verdict (``"pending"`` when
    there is none).

    Returns:
        The annotated ledger, or ``{}`` when no episode is loaded or no
        proposal has been made.
    """
    if self._episode is None:
        return {}

    slot_to_ad = self._proposal_slot_to_ad_id
    own_ad_ids = list(slot_to_ad.values())
    if not own_ad_ids:
        return {}

    every_target = [
        "payment_method",
        "landing_page",
        "targeting_overlap",
        "advertiser_history",
    ]
    pretend_investigations = {ad_id: every_target for ad_id in own_ad_ids}
    ledger = build_evidence_ledger(
        episode=self._episode,
        registry=self._registry,
        ad_ids=own_ad_ids,
        investigations=pretend_investigations,
    )

    ad_to_slot = {ad_id: slot for slot, ad_id in slot_to_ad.items()}
    verdicts = self._investigator.verdicts
    for ad_id, entry in ledger.items():
        if ad_id in ad_to_slot:
            entry["slot_index"] = ad_to_slot[ad_id]
        decided = verdicts.get(ad_id)
        if decided:
            entry["investigator_verdict"] = decided.get("verdict")
        else:
            entry["investigator_verdict"] = "pending"
    return ledger
def build_investigator_observation(self) -> AdReviewObservation:
    """Build the Investigator's observation via the wrapped investigator env.

    The observation is requested with ``reward=0.0`` and ``done`` set from
    the referee's phase. Its feedback is replaced by the referee's last
    Investigator feedback when that is non-empty.
    """
    is_done = self._phase == "done"
    observation = self._investigator._build_observation(  # noqa: SLF001
        reward=0.0, done=is_done
    )
    referee_feedback = self._last_feedback["investigator"]
    observation.feedback = referee_feedback or observation.feedback
    return observation
def build_auditor_observation(
    self, *, feedback: str = ""
) -> AuditorObservation:
    """Build the observation handed to the Auditor role.

    The observation bundles the full episode record (ads with ground
    truth, verdicts, links, fraud rings, grader score), copies of the
    Investigator and Fraudster action logs, the registry's
    ``to_dict()`` output and the audit flags raised so far.

    Args:
        feedback: Feedback text for this observation. When empty, the
            referee's last stored Auditor feedback is used instead.

    Returns:
        An ``AuditorObservation``. ``full_episode_record`` is ``{}`` when
        no episode is loaded, and ``investigation_data_seen`` is ``{}``
        when there is no registry.
    """
    phase = self._phase
    done = phase == "done"

    investigation_data_seen: Dict[str, Dict[str, str]] = {}
    if self._registry is not None:
        investigation_data_seen = self._registry.to_dict()

    record: Dict[str, Any] = {}
    if self._episode is not None:
        # Build the membership set once instead of scanning the dict's
        # values view for every ad, matching the other queue builders.
        proposal_ad_ids = set(self._proposal_slot_to_ad_id.values())
        record = {
            "task_id": self._task_id,
            "round_number": self._round_number,
            "proposals_used": self._proposals_used,
            "end_reason": self._end_reason,
            "ads": [
                {
                    "ad_id": ad.ad_id,
                    "ad_copy": ad.ad_copy,
                    "category": ad.category,
                    "ground_truth": ad.ground_truth_label,
                    "severity": ad.severity,
                    "fraud_type": ad.fraud_type,
                    "difficulty": ad.difficulty,
                    "is_fraudster_proposal": ad.ad_id in proposal_ad_ids,
                }
                for ad in self._episode.ads
            ],
            "verdicts": [
                {"ad_id": ad_id, **v}
                for ad_id, v in self._investigator.verdicts.items()
            ],
            "links": list(self._investigator.links),
            "grader_score": self._grader_score,
            "fraud_rings": [
                {
                    "ring_id": ring.ring_id,
                    "topology": ring.topology,
                    "case_name": ring.case_name,
                    "provenance": ring.provenance,
                    "member_ad_ids": list(ring.member_ad_ids),
                    "shared_signal_types": list(ring.shared_signals.keys()),
                }
                for ring in self._episode.fraud_rings
            ],
        }

    return AuditorObservation(
        done=done,
        reward=self._auditor_reward_total,
        feedback=feedback or self._last_feedback["auditor"],
        phase=phase,
        full_episode_record=record,
        investigator_actions=list(self._investigator_log),
        fraudster_proposals=list(self._fraudster_log),
        investigation_data_seen=investigation_data_seen,
        pending_flags=[f.model_dump() for f in self._audit_flags],
    )
| # ------------------------------------------------------------------ | |
| # State-machine helpers | |
| # ------------------------------------------------------------------ | |
def _guard_phase(self, expected: Phase, *, role: Role) -> None:
    """Ensure the match is currently in phase *expected*.

    Args:
        expected: The phase in which the action is legal.
        role: Name of the acting role, used only in the error message.

    Raises:
        PermissionError: If the current phase differs from *expected*.
    """
    if self._phase == expected:
        return
    raise PermissionError(
        f"{role} cannot act during phase '{self._phase}' "
        f"(expected '{expected}')."
    )
def _transition(self, *, to: Phase, note: str) -> None:
    """Move the state machine to phase *to* and reset the turn action counter.

    A transition to the phase the match is already in is a no-op: nothing
    is logged and the action counter is left untouched.

    Args:
        to: Target phase.
        note: Free-text reason, included in the debug log line only.
    """
    previous = self._phase
    if previous != to:
        logger.debug("[referee] %s -> %s (%s)", previous, to, note)
        self._phase = to
        self._actions_this_turn = 0
def _build_queue_summary(self) -> List[Dict[str, Any]]:
    """Summarise every ad in the episode queue for the Fraudster.

    Each row carries the ad's id, copy and category, its status (the
    Investigator's verdict, or ``"pending"`` when there is none) and
    whether it is one of the Fraudster's proposals. Proposal rows also
    carry their ``slot_index``.
    """
    assert self._episode is not None
    verdicts = self._investigator.verdicts
    # Inverted slot map: its keys double as the set of proposed ad ids.
    slot_for_ad = {
        ad_id: slot for slot, ad_id in self._proposal_slot_to_ad_id.items()
    }

    summary: List[Dict[str, Any]] = []
    for ad in self._episode.ads:
        ad_id = ad.ad_id
        decided = verdicts.get(ad_id)
        is_mine = ad_id in slot_for_ad
        row = {
            "ad_id": ad_id,
            "ad_copy": ad.ad_copy,
            "category": ad.category,
            "status": decided["verdict"] if decided else "pending",
            "is_my_proposal": is_mine,
        }
        if is_mine:
            row["slot_index"] = slot_for_ad[ad_id]
        summary.append(row)
    return summary
def _build_verdict_history(self) -> List[Dict[str, Any]]:
    """List the Investigator's logged verdict actions, in log order.

    Only log entries whose ``action_type`` is ``"verdict"`` are kept.
    Each is reduced to its ad id, verdict, confidence and rationale, plus
    a flag telling whether the ad was one of the Fraudster's proposals.
    """
    own_ad_ids = set(self._proposal_slot_to_ad_id.values())
    return [
        {
            "ad_id": logged.get("ad_id"),
            "verdict": logged.get("verdict"),
            "confidence": logged.get("confidence"),
            "rationale": logged.get("rationale"),
            "was_my_proposal": logged.get("ad_id") in own_ad_ids,
        }
        for logged in self._investigator_log
        if logged.get("action_type") == "verdict"
    ]
def _serialize_fraudster_action(
    self, action: FraudsterAction, reward: float
) -> Dict[str, Any]:
    """Flatten a Fraudster action into a plain log entry.

    The entry starts as a copy of the action's own fields plus a
    wall-clock timestamp, the current phase and round number, and
    *reward*. Queue actions are then enriched with env-resolved context
    so consumers can key onto a real ``ad_id``:

    * ``propose_ad``: ``ad_id`` and ``slot_index`` are taken from the
      ``_last_proposed_*`` scratch fields, which are cleared once read.
    * ``modify_pending_ad``: ``ad_id`` is resolved from the slot map, and
      every field the action left empty (``category``, ``ad_copy``,
      ``targeting_summary``, ``landing_page_blurb``) is back-filled,
      because the action itself only carries the delta fields.

    Args:
        action: The Fraudster action that was just applied.
        reward: Reward assigned to that action.

    Returns:
        The log entry as a dict; ``ad_id`` stays ``None`` when it could
        not be resolved.
    """
    payload: Dict[str, Any] = {
        "ts": time.time(),
        "phase": self._phase,
        "round_number": self._round_number,
        "action_type": action.action_type,
        "ad_copy": action.ad_copy,
        "category": action.category,
        "landing_page_blurb": action.landing_page_blurb,
        "targeting_summary": action.targeting_summary,
        "slot_index": action.slot_index,
        "new_ad_copy": action.new_ad_copy,
        "new_landing_page_blurb": action.new_landing_page_blurb,
        "rationale": action.rationale,
        "reward": reward,
        "ad_id": None,
    }

    if action.action_type == "propose_ad" and self._last_proposed_ad_id is not None:
        payload["ad_id"] = self._last_proposed_ad_id
        payload["slot_index"] = self._last_proposed_slot
        # One-shot hand-off: clear so a later action cannot reuse it.
        self._last_proposed_ad_id = None
        self._last_proposed_slot = None
    elif (
        action.action_type == "modify_pending_ad"
        and action.slot_index is not None
        and action.slot_index in self._proposal_slot_to_ad_id
    ):
        ad_id = self._proposal_slot_to_ad_id[action.slot_index]
        payload["ad_id"] = ad_id
        ad = self._find_episode_ad(ad_id)
        if ad is not None:
            # "category" is always present in the payload (possibly as
            # None), so dict.setdefault would never fill it. Test for an
            # empty value instead, exactly like the fields below.
            if not payload.get("category"):
                payload["category"] = ad.category
            if not payload.get("ad_copy"):
                payload["ad_copy"] = action.new_ad_copy or ad.ad_copy
            if not payload.get("targeting_summary"):
                payload["targeting_summary"] = ad.targeting_summary
            if (
                not payload.get("landing_page_blurb")
                and self._episode is not None
            ):
                lp = self._episode.landing_pages.get(ad_id)
                if lp is not None:
                    payload["landing_page_blurb"] = (
                        action.new_landing_page_blurb
                        or lp.content_summary
                    )
    return payload
def _find_episode_ad(self, ad_id: str) -> Optional[Ad]:
    """Return the first episode ad whose ``ad_id`` equals *ad_id*.

    Returns ``None`` when no episode is loaded or no ad matches.
    """
    episode = self._episode
    if episode is None:
        return None
    return next((ad for ad in episode.ads if ad.ad_id == ad_id), None)
def _serialize_investigator_action(
    self, action: AdReviewAction, obs: AdReviewObservation
) -> Dict[str, Any]:
    """Flatten an Investigator action and its resulting observation into a log entry.

    Args:
        action: The Investigator action that was applied.
        obs: The observation produced by that action; supplies the
            reward (a missing/falsy value becomes ``0.0``) and the
            feedback excerpt (first 500 characters).

    Returns:
        The log entry as a dict.
    """
    entry: Dict[str, Any] = {
        "ts": time.time(),
        "phase": self._phase,
        "round_number": self._round_number,
    }
    # These fields are copied verbatim from the action, in log-key order.
    copied_fields = (
        "action_type",
        "ad_id",
        "investigation_target",
        "verdict",
        "confidence",
        "rationale",
        "linked_ad_id",
        "link_reason",
    )
    for field_name in copied_fields:
        entry[field_name] = getattr(action, field_name)
    entry["reward"] = float(obs.reward or 0.0)
    entry["findings_excerpt"] = (obs.feedback or "")[:500]
    return entry
| # ------------------------------------------------------------------ | |
| # Introspection helpers for the driver / clients | |
| # ------------------------------------------------------------------ | |
def phase(self) -> Phase:
    """Return the state machine's current phase."""
    return self._phase
def done(self) -> bool:
    """Return the stored ``_done`` flag."""
    return self._done
def match_id(self) -> str:
    """Return the identifier of this match."""
    return self._match_id
def episode(self) -> Optional[GeneratedEpisode]:
    """Return the loaded episode, or ``None`` when none is loaded."""
    return self._episode
def registry(self) -> Optional[InvestigationToolRegistry]:
    """Return the investigation tool registry, or ``None`` when unset."""
    return self._registry
def investigator(self) -> InvestigatorEnvironment:
    """Return the wrapped Investigator environment."""
    return self._investigator
def grader_score(self) -> Optional[float]:
    """Return the stored grader score, or ``None`` when none is set."""
    return self._grader_score