Spaces:
Running
Running
| """ | |
| Core Ad Fraud Investigation Environment. | |
| Implements the OpenEnv Environment interface. The agent reviews a queue of ads, | |
| investigates them, and renders verdicts under a budget constraint. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any, Dict, List, Optional, Set | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..data.ad_generator import ( | |
| TASK_CONFIGS, | |
| Ad, | |
| GeneratedEpisode, | |
| generate_episode, | |
| ) | |
| from ..models import AdFraudState, AdReviewAction, AdReviewObservation | |
| from ..graders.base_grader import EpisodeRecord, LinkResult, VerdictResult, grade_episode | |
| except ImportError: | |
| from data.ad_generator import ( | |
| TASK_CONFIGS, | |
| Ad, | |
| GeneratedEpisode, | |
| generate_episode, | |
| ) | |
| from models import AdFraudState, AdReviewAction, AdReviewObservation | |
| from graders.base_grader import EpisodeRecord, LinkResult, VerdictResult, grade_episode | |
logger = logging.getLogger(__name__)

# Summary of the most recently finished episode, kept at module level so the
# /grader endpoint can read the last score without holding an environment
# instance. It is rebound (not mutated) each time an episode ends.
_last_grader_result: Dict[str, Any] = {}
def get_last_grader_result() -> Dict[str, Any]:
    """Return a shallow copy of the most recent episode's grader summary.

    The copy keeps callers from mutating the module-level store; it is an
    empty dict until an episode has finished.
    """
    return _last_grader_result.copy()
class AdFraudEnvironment(
    Environment[AdReviewAction, AdReviewObservation, AdFraudState]
):
    """
    Ad fraud investigation environment.

    Each episode is a review session: the agent processes a queue of N ads
    within a limited action budget, choosing what to investigate and when
    to render verdicts. Unreviewed ads auto-approve at episode end.

    Per-step rewards are computed in this class. The end-of-episode grader
    score is stored on the state object and in the module-level
    ``_last_grader_result``; it is not added to the step reward.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Investigation targets advertised to the agent in the focused-ad view.
    _INVESTIGATION_TARGETS = (
        "advertiser_history",
        "landing_page",
        "payment_method",
        "targeting_overlap",
        "creative_similarity",
        "campaign_structure",
    )

    # Verdicts that always get a line in the verdict history, in display order.
    _KNOWN_VERDICTS = ("reject", "approve", "escalate")

    def __init__(self) -> None:
        """Create an environment with no episode loaded; call ``reset()`` next."""
        super().__init__()
        self._state = AdFraudState(episode_id=str(uuid4()), step_count=0)
        self._episode: Optional[GeneratedEpisode] = None
        # ad_id -> {"verdict", "confidence", "ground_truth"[, "auto_approved"]}
        self._verdicts: Dict[str, Dict[str, Any]] = {}
        # Each link: {"ad_id_1", "ad_id_2", "reason", "correct"}
        self._links: List[Dict[str, Any]] = []
        # ad_id -> investigation targets already performed, in order
        self._investigations: Dict[str, List[str]] = {}
        self._cumulative_reward: float = 0.0
        self._done = False
        self._last_feedback = ""
        self._focused_ad_id: Optional[str] = None

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        **kwargs: Any,
    ) -> AdReviewObservation:
        """Start a new episode and return its first observation.

        Args:
            seed: Seed passed to ``generate_episode``; a random 32-bit value
                is drawn when omitted.
            episode_id: Identifier stored on the state; a fresh UUID is used
                when omitted or empty.
            **kwargs: ``task_id`` selects the task. A missing id, or one not
                present in ``TASK_CONFIGS``, falls back to ``"task_1"``.

        Returns:
            The initial observation (reward 0.0, not done).
        """
        task_id = kwargs.get("task_id", "task_1")
        if task_id not in TASK_CONFIGS:
            task_id = "task_1"
        effective_seed = seed if seed is not None else hash(uuid4()) & 0xFFFFFFFF

        self._episode = generate_episode(effective_seed, task_id)
        config = self._episode.task_config
        self._state = AdFraudState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=task_id,
            total_ads=config.queue_size,
            reviewed_count=0,
            remaining_budget=config.action_budget,
            verdicts={},
            grader_score=None,
        )
        self._verdicts = {}
        self._links = []
        self._investigations = {}
        self._cumulative_reward = 0.0
        self._done = False
        self._last_feedback = "Episode started. Review the ad queue and begin your investigation."
        self._focused_ad_id = self._episode.ads[0].ad_id if self._episode.ads else None
        return self._build_observation(reward=0.0, done=False)

    def step(
        self,
        action: AdReviewAction,
        timeout_s: float | None = None,
        **kwargs: Any,
    ) -> AdReviewObservation:
        """Apply one agent action and return the resulting observation.

        Every call on a live episode counts as a step, including calls with
        an invalid ``ad_id`` or an unknown ``action_type`` (both are
        penalised with -0.05). After the action is handled, the reward is
        accumulated and the termination check runs.

        Args:
            action: The action to apply.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The observation after the action. When the episode has already
            finished, or no episode is loaded, a zero-reward observation with
            an explanatory feedback message is returned and nothing changes.
        """
        if self._done:
            return self._build_observation(
                reward=0.0,
                done=True,
                feedback_override="Episode is already complete. Call reset() to start a new episode.",
            )
        if self._episode is None:
            return self._build_observation(
                reward=0.0,
                done=False,
                feedback_override="Environment not initialized. Call reset() first.",
            )

        self._state.step_count += 1

        valid_ids = {ad.ad_id for ad in self._episode.ads}
        if action.ad_id not in valid_ids:
            # Handled like any other penalised action so that the shared
            # bookkeeping below (cumulative reward, termination check) still
            # runs; returning early here would let invalid ids bypass the
            # step budget indefinitely.
            self._last_feedback = (
                f"Invalid ad_id '{action.ad_id}'. "
                f"Valid IDs: {', '.join(sorted(valid_ids))}"
            )
            reward = -0.05
        elif action.action_type == "investigate":
            reward = self._handle_investigate(action)
        elif action.action_type == "verdict":
            reward = self._handle_verdict(action)
        elif action.action_type == "link_accounts":
            reward = self._handle_link(action)
        else:
            self._last_feedback = f"Unknown action_type '{action.action_type}'."
            reward = -0.05

        self._cumulative_reward += reward

        # self._done is known to be False here (guarded at the top).
        if self._check_done():
            end_reward = self._handle_episode_end()
            reward += end_reward
            self._cumulative_reward += end_reward
            self._done = True

        self._state.remaining_budget = max(0, self._state.remaining_budget)
        self._state.reviewed_count = len(self._verdicts)
        self._state.verdicts = {
            ad_id: entry.get("verdict", "") for ad_id, entry in self._verdicts.items()
        }
        return self._build_observation(reward=reward, done=self._done)

    # NOTE(review): this is a plain method, so callers must use ``env.state()``.
    # If the Environment base class declares ``state`` as a property, confirm
    # whether this should be decorated with ``@property``.
    def state(self) -> AdFraudState:
        """Return the current episode state object."""
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------
    def _handle_investigate(self, action: AdReviewAction) -> float:
        """Run one investigation on an ad and return the step reward.

        A successful investigation costs one unit of ``remaining_budget``,
        focuses the ad, and puts the findings into the feedback. Rejected
        requests (no budget, missing target, repeated target, ad already
        decided) change nothing except the feedback message.

        Returns:
            -0.05 when ``investigation_target`` is missing, otherwise -0.02.
        """
        if self._state.remaining_budget <= 0:
            self._last_feedback = "No budget remaining. You must render verdicts on remaining ads or end the episode."
            return -0.02
        if action.investigation_target is None:
            self._last_feedback = "investigation_target is required for action_type='investigate'."
            return -0.05

        ad_id = action.ad_id
        target = action.investigation_target

        # Read-only lookup: an entry is only created once the investigation
        # is actually performed.
        if target in self._investigations.get(ad_id, []):
            self._last_feedback = (
                f"You already investigated '{target}' for {ad_id}. "
                "Choose a different investigation target or render a verdict."
            )
            return -0.02
        if ad_id in self._verdicts:
            self._last_feedback = f"You already rendered a verdict on {ad_id}. Choose a different ad."
            return -0.02

        self._state.remaining_budget -= 1
        self._investigations.setdefault(ad_id, []).append(target)
        self._focused_ad_id = ad_id

        findings = self._episode.investigation_data.get(ad_id, {}).get(
            target, "No data available for this investigation type."
        )
        self._last_feedback = (
            f"Investigation complete: {target} for {ad_id}.\n"
            f"--- Findings ---\n{findings}"
        )
        return -0.02

    def _handle_verdict(self, action: AdReviewAction) -> float:
        """Record a verdict for an ad and return its reward.

        A missing ``confidence`` defaults to 0.5. After recording, focus
        moves to the first ad that is still pending, if any.

        Returns:
            -0.02 when the ad already has a verdict, -0.05 when ``verdict``
            is missing, otherwise the value from ``_compute_verdict_reward``.
        """
        ad_id = action.ad_id
        if ad_id in self._verdicts:
            self._last_feedback = f"You already rendered a verdict on {ad_id}."
            return -0.02
        if action.verdict is None:
            self._last_feedback = "verdict field is required for action_type='verdict'."
            return -0.05

        confidence = action.confidence if action.confidence is not None else 0.5
        ad = self._get_ad(ad_id)
        ground_truth = ad.ground_truth_label if ad else "legit"
        severity = ad.severity if ad else 0.0

        self._verdicts[ad_id] = {
            "verdict": action.verdict,
            "confidence": confidence,
            "ground_truth": ground_truth,
        }
        reward = self._compute_verdict_reward(action.verdict, ground_truth, severity, confidence)

        pending = [a.ad_id for a in self._episode.ads if a.ad_id not in self._verdicts]
        self._last_feedback = (
            f"Verdict recorded for {ad_id}: {action.verdict} "
            f"(confidence: {confidence:.2f}). "
            f"{len(pending)} ad(s) remaining in queue."
        )
        if pending:
            self._focused_ad_id = pending[0]
        return reward

    def _handle_link(self, action: AdReviewAction) -> float:
        """Record a claimed network link between two ads and return the reward.

        Links are undirected: A<->B and B<->A count as the same link.

        Returns:
            -0.05 for a missing/invalid ``linked_ad_id`` or a self-link,
            -0.02 for a link that was already recorded, 0.4 for a new link
            whose ads share a fraud ring, and -0.25 for a new incorrect link.
        """
        if action.linked_ad_id is None:
            self._last_feedback = "linked_ad_id is required for action_type='link_accounts'."
            return -0.05
        known_ids = {a.ad_id for a in self._episode.ads}
        if action.linked_ad_id not in known_ids:
            self._last_feedback = f"Invalid linked_ad_id '{action.linked_ad_id}'."
            return -0.05
        if action.ad_id == action.linked_ad_id:
            self._last_feedback = "Cannot link an ad to itself."
            return -0.05

        # Sort each pair so the comparison ignores direction.
        link_key = tuple(sorted([action.ad_id, action.linked_ad_id]))
        recorded = {
            tuple(sorted([link["ad_id_1"], link["ad_id_2"]])) for link in self._links
        }
        if link_key in recorded:
            self._last_feedback = f"Link between {action.ad_id} and {action.linked_ad_id} already recorded."
            return -0.02

        is_correct = self._check_link_correct(action.ad_id, action.linked_ad_id)
        self._links.append({
            "ad_id_1": action.ad_id,
            "ad_id_2": action.linked_ad_id,
            "reason": action.link_reason or "",
            "correct": is_correct,
        })
        self._last_feedback = (
            f"Network link recorded: {action.ad_id} <-> {action.linked_ad_id}. "
            f"Reason: {action.link_reason or 'not specified'}."
        )
        return 0.4 if is_correct else -0.25

    # ------------------------------------------------------------------
    # Reward computation
    # ------------------------------------------------------------------
    def _compute_verdict_reward(
        self, verdict: str, ground_truth: str, severity: float, confidence: float
    ) -> float:
        """Return the step reward for a verdict given the ad's ground truth.

        ``severity`` only scales the reward for correctly rejected fraud.
        ``confidence`` is accepted but does not affect the result. Any
        combination not listed below (including unrecognised verdict
        strings) yields -0.05.
        """
        if verdict == "reject" and ground_truth == "fraud":
            return 0.3 + 0.1 * severity
        if verdict == "approve" and ground_truth == "legit":
            return 0.1
        if verdict == "escalate" and ground_truth == "escalate":
            return 0.15
        if verdict == "reject" and ground_truth == "legit":
            return -0.35
        if verdict == "approve" and ground_truth == "fraud":
            return -0.5
        if verdict == "escalate":
            # Escalating something that did not need it.
            return -0.05
        if verdict == "approve" and ground_truth == "escalate":
            return -0.15
        if verdict == "reject" and ground_truth == "escalate":
            return -0.1
        return -0.05

    def _handle_episode_end(self) -> float:
        """Apply end-of-episode adjustments for unreviewed ads, then delegate to graders.

        Ads without a verdict are auto-approved (confidence 0.0, flagged
        ``auto_approved``) before grading. The grader score is stored on the
        state, a summary is published to ``_last_grader_result``, and the
        feedback is replaced with an end-of-episode report.

        Returns:
            Always 0.0; the grader score is reported separately from reward.
        """
        unreviewed_fraud = 0
        for ad in self._episode.ads:
            if ad.ad_id in self._verdicts:
                continue
            self._verdicts[ad.ad_id] = {
                "verdict": "approve",
                "confidence": 0.0,
                "ground_truth": ad.ground_truth_label,
                "auto_approved": True,
            }
            if ad.ground_truth_label == "fraud":
                unreviewed_fraud += 1

        grader_score = grade_episode(self._build_episode_record())
        self._state.grader_score = grader_score

        # Decision statistics only cover verdicts the agent rendered itself.
        manual = [v for v in self._verdicts.values() if not v.get("auto_approved")]
        outcomes = [(v["verdict"], v["ground_truth"]) for v in manual]
        correct_pairs = {
            ("reject", "fraud"),
            ("approve", "legit"),
            ("escalate", "escalate"),
        }
        reviewed_count = len(manual)
        total_ads = len(self._episode.ads)
        auto_approved = total_ads - reviewed_count
        total_correct = sum(1 for pair in outcomes if pair in correct_pairs)
        false_positives = outcomes.count(("reject", "legit"))
        false_negatives = outcomes.count(("approve", "fraud"))
        correct_links = sum(1 for link in self._links if link.get("correct"))
        incorrect_links = len(self._links) - correct_links

        global _last_grader_result
        _last_grader_result = {
            "task_id": self._state.task_id,
            "grader_score": grader_score,
            "episode_id": self._state.episode_id,
            "total_steps": self._state.step_count,
            "verdicts_rendered": reviewed_count,
            "correct_decisions": total_correct,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "auto_approved": auto_approved,
            "unreviewed_fraud": unreviewed_fraud,
            "network_links_correct": correct_links,
            "network_links_incorrect": incorrect_links,
        }

        feedback_lines = [
            f"Episode complete. Grader score: {grader_score:.3f}/1.000",
            f"Verdicts rendered: {reviewed_count}/{total_ads}",
            f"Correct decisions: {total_correct}/{reviewed_count}",
            f"False positives (legit rejected): {false_positives}",
            f"False negatives (fraud approved): {false_negatives}",
            # Report the number of auto-approved ads (not just the fraudulent
            # ones), and call out how many of them were fraud.
            f"Unreviewed ads auto-approved: {auto_approved} ({unreviewed_fraud} fraud)",
        ]
        if self._links:
            feedback_lines.append(
                f"Network links: {correct_links} correct, {incorrect_links} incorrect"
            )
        self._last_feedback = "\n".join(feedback_lines)
        return 0.0

    def _build_episode_record(self) -> EpisodeRecord:
        """Convert internal state into an EpisodeRecord for the grader."""
        verdict_results = []
        for ad in self._episode.ads:
            entry = self._verdicts.get(ad.ad_id)
            if entry:
                verdict_results.append(VerdictResult(
                    ad_id=ad.ad_id,
                    verdict=entry["verdict"],
                    confidence=entry.get("confidence", 0.5),
                    ground_truth=entry["ground_truth"],
                    auto_approved=entry.get("auto_approved", False),
                ))

        link_results = [
            LinkResult(
                ad_id_1=link["ad_id_1"],
                ad_id_2=link["ad_id_2"],
                correct=link["correct"],
            )
            for link in self._links
        ]
        ads_metadata = [
            {"ad_id": ad.ad_id, "ground_truth": ad.ground_truth_label, "severity": ad.severity}
            for ad in self._episode.ads
        ]
        rings = self._episode.fraud_rings
        return EpisodeRecord(
            task_id=self._state.task_id,
            total_steps=self._state.step_count,
            action_budget=self._episode.task_config.action_budget,
            verdicts=verdict_results,
            links=link_results,
            ads_metadata=ads_metadata,
            n_fraud_rings=len(rings),
            ring_sizes=[len(ring.member_ad_ids) for ring in rings],
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _check_done(self) -> bool:
        """Return True when every ad has a verdict or the step budget is spent.

        Also returns True when no episode is loaded.
        """
        if self._episode is None:
            return True
        if all(ad.ad_id in self._verdicts for ad in self._episode.ads):
            return True
        return self._state.step_count >= self._episode.task_config.action_budget

    def _check_link_correct(self, ad_id_1: str, ad_id_2: str) -> bool:
        """Check if two ads share a fraud ring."""
        return any(
            ad_id_1 in ring.member_ad_ids and ad_id_2 in ring.member_ad_ids
            for ring in self._episode.fraud_rings
        )

    def _get_ad(self, ad_id: str) -> Optional[Ad]:
        """Return the ad with the given id, or None if absent or no episode is loaded."""
        if self._episode is None:
            return None
        return next((ad for ad in self._episode.ads if ad.ad_id == ad_id), None)

    def _build_observation(
        self,
        reward: float,
        done: bool,
        feedback_override: str | None = None,
    ) -> AdReviewObservation:
        """Assemble the observation returned to the agent.

        Args:
            reward: Reward to report for this step.
            done: Whether the episode is finished; when True the focused-ad
                section is left empty.
            feedback_override: Feedback to show instead of the stored
                ``_last_feedback``; ignored when empty or None.

        Returns:
            A populated observation, or a placeholder observation when no
            episode is loaded.
        """
        feedback = feedback_override or self._last_feedback
        if self._episode is None:
            return AdReviewObservation(
                done=done,
                reward=reward,
                queue_summary="No episode loaded.",
                current_ad_info="",
                investigation_findings="",
                verdict_history_summary="",
                feedback=feedback,
                available_ads=[],
                queue_status={},
            )

        config = self._episode.task_config
        pending_ids = [a.ad_id for a in self._episode.ads if a.ad_id not in self._verdicts]
        reviewed_total = sum(1 for a in self._episode.ads if a.ad_id in self._verdicts)
        steps_remaining = max(0, config.action_budget - self._state.step_count)

        queue_summary = (
            f"Task: {config.name} ({config.difficulty})\n"
            f"Total ads: {config.queue_size} | "
            f"Reviewed: {reviewed_total} | "
            f"Pending: {len(pending_ids)} | "
            f"Steps remaining: {steps_remaining}/{config.action_budget} | "
            f"Investigation budget: {self._state.remaining_budget} | "
            f"Step: {self._state.step_count}"
        )
        queue_status = {
            "total_ads": config.queue_size,
            "reviewed": reviewed_total,
            "pending": len(pending_ids),
            "investigation_budget": self._state.remaining_budget,
            "steps_remaining": steps_remaining,
            "step": self._state.step_count,
            "task_id": config.task_id,
        }
        return AdReviewObservation(
            done=done,
            reward=reward,
            queue_summary=queue_summary,
            current_ad_info=self._describe_focused_ad(done),
            investigation_findings=self._collect_findings(),
            verdict_history_summary=self._summarize_verdicts(),
            feedback=feedback,
            available_ads=pending_ids,
            queue_status=queue_status,
        )

    def _describe_focused_ad(self, done: bool) -> str:
        """Render the focused ad's details, or "" if there is nothing to show.

        Nothing is shown when the episode is done, no ad is focused, the
        focused id is unknown, or the focused ad already has a verdict.
        """
        if done or not self._focused_ad_id:
            return ""
        ad = self._get_ad(self._focused_ad_id)
        if ad is None or ad.ad_id in self._verdicts:
            return ""

        signals = ", ".join(ad.initial_risk_signals) if ad.initial_risk_signals else "None"
        investigated = self._investigations.get(ad.ad_id, [])
        inv_status = ", ".join(investigated) if investigated else "None yet"

        # Contextual metadata visible before investigation
        meta_lines = []
        profile = self._episode.advertiser_profiles.get(ad.ad_id)
        if profile:
            meta_lines.append(f"Advertiser country: {profile.country}")
            meta_lines.append(f"Account age: {profile.account_age_days} days")
            if profile.account_age_days < 30:
                meta_lines.append("Flag: New account (< 30 days)")
        context_meta = "\n".join(meta_lines)

        return (
            f"=== Ad in Focus: {ad.ad_id} ===\n"
            f"Category: {ad.category}\n"
            f"Ad copy: \"{ad.ad_copy}\"\n"
            f"Targeting: {ad.targeting_summary}\n"
            f"Initial risk signals: {signals}\n"
            f"{context_meta}\n"
            f"Investigations completed: {inv_status}\n"
            f"Available investigation targets: {', '.join(self._INVESTIGATION_TARGETS)}"
        )

    def _collect_findings(self) -> str:
        """Concatenate the findings of every investigation performed so far."""
        sections = []
        for ad_id, targets in self._investigations.items():
            per_ad = self._episode.investigation_data.get(ad_id, {})
            for target in targets:
                finding = per_ad.get(target, "")
                if finding:
                    sections.append(f"\n[{ad_id} / {target}]\n{finding}\n")
        return "".join(sections).strip()

    def _summarize_verdicts(self) -> str:
        """Summarise agent-rendered verdicts (auto-approvals are excluded)."""
        grouped: Dict[str, List[str]] = {"approve": [], "reject": [], "escalate": []}
        manual_total = 0
        for ad_id, entry in self._verdicts.items():
            if entry.get("auto_approved"):
                continue
            manual_total += 1
            # setdefault keeps an unexpected verdict string from raising
            # KeyError; it is simply listed under its own heading.
            grouped.setdefault(entry["verdict"], []).append(ad_id)

        if not manual_total:
            return "No verdicts yet."

        summary_parts = [f"{len(ids)} {decision}" for decision, ids in grouped.items() if ids]
        lines = [f"Reviewed {manual_total} ad(s): {', '.join(summary_parts)}."]
        extra = [d for d in grouped if d not in self._KNOWN_VERDICTS]
        for decision in (*self._KNOWN_VERDICTS, *extra):
            if grouped[decision]:
                lines.append(f"  {decision}: {', '.join(grouped[decision])}")
        return "\n".join(lines)