""" CounterFeint — multi-agent FraudArena client library. Exports: FraudsterClient WebSocket client for `/ws/fraudster` InvestigatorClient WebSocket client for `/ws/investigator` AuditorClient WebSocket client for `/ws/auditor` MatchClient Convenience coordinator that owns all three WS connections, shares a single `match_id`, and exposes a flat async API. AdFraudEnv Legacy single-agent client (R1 compatibility). Speaks to `/ws` (Investigator-only). Kept so existing R1 inference / baseline scripts run without change. Example (three-agent): async with MatchClient("ws://localhost:8000") as match: await match.reset(seed=42, task_id="task_1") # match.fraudster.step(...), match.investigator.step(...), etc. state = await match.state() Example (R1 single-agent): env = AdFraudEnv(base_url="http://localhost:8000").sync() env.connect() result = env.reset(seed=42, task_id="task_1") result = env.step(AdReviewAction( action_type="verdict", ad_id="ad_001", verdict="approve", confidence=0.8, )) """ from __future__ import annotations import asyncio import json import logging from types import TracebackType from typing import Any, Dict, Optional, Type import websockets from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient from .models import ( AdFraudState, AdReviewAction, AdReviewObservation, AuditorAction, AuditorObservation, FraudsterAction, FraudsterObservation, RefereeState, ) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Legacy R1 single-agent client (unchanged) # --------------------------------------------------------------------------- class AdFraudEnv(EnvClient[AdReviewAction, AdReviewObservation, AdFraudState]): """ R1 single-agent WebSocket client (Investigator-only). Kept for backwards compatibility with the Round-1 baseline script and for OpenEnv's standard `/ws` route. For Round-2 multi-agent workflows use `MatchClient` or the per-role clients. """ def _step_payload(self, action: AdReviewAction) -> Dict[str, Any]: return action.model_dump(exclude_none=True, exclude={"metadata"}) def _parse_result( self, payload: Dict[str, Any] ) -> StepResult[AdReviewObservation]: obs_data = payload.get("observation", {}) reward = payload.get("reward", 0.0) or 0.0 done = payload.get("done", False) observation = AdReviewObservation( done=done, reward=reward, queue_summary=obs_data.get("queue_summary", ""), current_ad_info=obs_data.get("current_ad_info", ""), investigation_findings=obs_data.get("investigation_findings", ""), verdict_history_summary=obs_data.get("verdict_history_summary", ""), feedback=obs_data.get("feedback", ""), available_ads=obs_data.get("available_ads", []), queue_status=obs_data.get("queue_status", {}), queue_may_grow=obs_data.get("queue_may_grow", False), evidence_ledger=obs_data.get("evidence_ledger", {}), queue_digest=obs_data.get("queue_digest", []), decided_ads=obs_data.get("decided_ads", []), metadata=obs_data.get("metadata", {}), ) return StepResult(observation=observation, reward=reward, done=done) def _parse_state(self, payload: Dict[str, Any]) -> AdFraudState: return AdFraudState( episode_id=payload.get("episode_id"), step_count=payload.get("step_count", 0), task_id=payload.get("task_id", ""), total_ads=payload.get("total_ads", 0), reviewed_count=payload.get("reviewed_count", 0), remaining_budget=payload.get("remaining_budget", 0), verdicts=payload.get("verdicts", {}), grader_score=payload.get("grader_score"), ) # --------------------------------------------------------------------------- # Shared multi-agent client infrastructure # --------------------------------------------------------------------------- class MultiAgentProtocolError(RuntimeError): """Raised when the server returns a protocol-level error (validation, phase violation, ...).""" def __init__(self, message: str, code: str = "execution_error") -> None: super().__init__(message) self.code = code def _http_to_ws(base_url: str) -> str: """Normalize an http(s)://... base URL to ws(s)://... (idempotent).""" if base_url.startswith("http://"): return "ws://" + base_url[len("http://"):] if base_url.startswith("https://"): return "wss://" + base_url[len("https://"):] return base_url class _RoleClient: """ Async WebSocket client for a single role. All three role-specific subclasses share this logic. The only variation is the WS path (`/ws/fraudster`, `/ws/investigator`, `/ws/auditor`) and the action/observation Pydantic types (enforced via generics in the subclasses). """ ws_path: str = "" # override in subclasses action_cls: Type[Any] = Any # type: ignore[assignment] observation_cls: Type[Any] = Any # type: ignore[assignment] def __init__(self, base_url: str, *, timeout: float = 30.0) -> None: self._ws_base = _http_to_ws(base_url.rstrip("/")) self._url = f"{self._ws_base}{self.ws_path}" self._timeout = timeout self._ws: Optional[Any] = None self._match_id: Optional[str] = None @property def match_id(self) -> Optional[str]: return self._match_id @property def connected(self) -> bool: return self._ws is not None async def connect(self) -> None: if self._ws is not None: return self._ws = await websockets.connect( self._url, open_timeout=self._timeout, ping_interval=None, ) async def close(self) -> None: if self._ws is None: return try: await self._send("close", {}) except Exception: pass try: await self._ws.close() except Exception: pass self._ws = None async def __aenter__(self) -> "_RoleClient": await self.connect() return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ) -> None: await self.close() async def reset(self, **kwargs: Any) -> Dict[str, Any]: """ Create a new match. Only one role needs to call `reset` per match (typically the Fraudster). Returns the role's initial observation. """ await self.connect() data = {k: v for k, v in kwargs.items() if v is not None} await self._send("reset", data) payload = await self._recv_observation() self._match_id = payload.get("match_id") return payload async def join(self, match_id: str) -> Dict[str, Any]: """Attach to an existing match (another role already called `reset`).""" await self.connect() await self._send("join", {"match_id": match_id}) payload = await self._recv_observation() self._match_id = match_id return payload async def step(self, action: Any) -> Dict[str, Any]: """Execute an action. Raises MultiAgentProtocolError on validation / phase errors.""" await self._require_connected() payload = ( action.model_dump(exclude_none=True) if hasattr(action, "model_dump") else dict(action) ) await self._send("step", payload) return await self._recv_observation() async def obs(self) -> Dict[str, Any]: """Return the current observation without stepping.""" await self._require_connected() await self._send("obs", {}) return await self._recv_observation() async def state(self) -> Dict[str, Any]: """Return the full shared `RefereeState` for this match.""" await self._require_connected() await self._send("state", {}) msg = await self._recv_any() if msg.get("type") != "state": raise MultiAgentProtocolError( f"expected state response, got {msg!r}", code="protocol_error" ) return msg["data"] async def _send(self, msg_type: str, data: Dict[str, Any]) -> None: assert self._ws is not None await self._ws.send(json.dumps({"type": msg_type, "data": data})) async def _recv_any(self) -> Dict[str, Any]: assert self._ws is not None raw = await asyncio.wait_for(self._ws.recv(), timeout=self._timeout) msg = json.loads(raw) if msg.get("type") == "error": err = msg.get("data", {}) raise MultiAgentProtocolError( err.get("message", "unknown error"), code=err.get("code", "execution_error"), ) return msg async def _recv_observation(self) -> Dict[str, Any]: msg = await self._recv_any() if msg.get("type") != "observation": raise MultiAgentProtocolError( f"expected observation, got {msg!r}", code="protocol_error" ) return msg["data"] async def _require_connected(self) -> None: if self._ws is None: raise RuntimeError( f"{type(self).__name__} is not connected; call connect()/reset()/join() first" ) class FraudsterClient(_RoleClient): """Fraudster agent (proposes / modifies ads). Connects to `/ws/fraudster`.""" ws_path = "/ws/fraudster" action_cls = FraudsterAction observation_cls = FraudsterObservation class InvestigatorClient(_RoleClient): """Investigator agent (investigates + verdicts). Connects to `/ws/investigator`.""" ws_path = "/ws/investigator" action_cls = AdReviewAction observation_cls = AdReviewObservation class AuditorClient(_RoleClient): """Auditor agent (audits both peers post-hoc). Connects to `/ws/auditor`.""" ws_path = "/ws/auditor" action_cls = AuditorAction observation_cls = AuditorObservation # --------------------------------------------------------------------------- # MatchClient — convenience coordinator # --------------------------------------------------------------------------- class MatchClient: """ Convenience wrapper owning three role-specific WS clients plus a shared `match_id`. Handles the dance of: 1. Fraudster connects + resets a match 2. Investigator and Auditor join using the returned `match_id` Use as an async context manager, or call `connect()`/`close()` manually. """ def __init__(self, base_url: str, *, timeout: float = 30.0) -> None: self.base_url = base_url self.fraudster = FraudsterClient(base_url, timeout=timeout) self.investigator = InvestigatorClient(base_url, timeout=timeout) self.auditor = AuditorClient(base_url, timeout=timeout) self._match_id: Optional[str] = None @property def match_id(self) -> Optional[str]: return self._match_id async def __aenter__(self) -> "MatchClient": return self async def __aexit__( self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[TracebackType], ) -> None: await self.close() async def reset(self, **kwargs: Any) -> Dict[str, Any]: """ Open the three WS connections, create a match via the Fraudster, and have the Investigator + Auditor join. Returns the Fraudster's initial observation (includes `match_id`). """ fraud_obs = await self.fraudster.reset(**kwargs) match_id = fraud_obs.get("match_id") if not match_id: raise MultiAgentProtocolError( "server did not return match_id on reset", code="protocol_error" ) self._match_id = match_id await self.investigator.join(match_id) await self.auditor.join(match_id) return fraud_obs async def state(self) -> RefereeState: """Return the shared RefereeState as a typed Pydantic object.""" if self._match_id is None: raise RuntimeError("no active match; call reset() first") data = await self.fraudster.state() return RefereeState.model_validate(data) async def close(self) -> None: """Close all three WS connections. Safe to call multiple times.""" for client in (self.fraudster, self.investigator, self.auditor): try: await client.close() except Exception: logger.debug("client.close() raised", exc_info=True) __all__ = [ "AdFraudEnv", "AuditorClient", "FraudsterClient", "InvestigatorClient", "MatchClient", "MultiAgentProtocolError", ]