invoiceguard-code / client.py
piyush-mk's picture
Sync InvoiceGuard code for GRPO training job
9a88af0 verified
"""InvoiceGuard Environment Client."""
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
try:
from .models import (
InvoiceGuardAction,
InvoiceGuardObservation,
InvoiceGuardState,
)
except (ImportError, ModuleNotFoundError):
from models import (
InvoiceGuardAction,
InvoiceGuardObservation,
InvoiceGuardState,
)
class InvoiceGuardEnv(
    EnvClient[InvoiceGuardAction, InvoiceGuardObservation, InvoiceGuardState]
):
    """
    Client for the InvoiceGuard environment.

    Binds ``EnvClient`` to the InvoiceGuard action, observation and state
    models and provides the three conversion hooks between those models and
    plain dict payloads.

    Example:
        >>> async with InvoiceGuardEnv(base_url="http://localhost:8000") as client:
        ...     result = await client.reset(task_id="task_1_clean_match")
        ...     print(result.observation.invoice_summary)
    """

    def _step_payload(self, action: InvoiceGuardAction) -> Dict:
        """Turn *action* into a dict payload via ``model_dump(exclude_none=True)``."""
        serialized = action.model_dump(exclude_none=True)
        return serialized

    def _parse_result(self, payload: Dict) -> StepResult[InvoiceGuardObservation]:
        """
        Build a ``StepResult`` from a step/reset response payload.

        Observation fields are read from ``payload["observation"]``; any key
        that is absent falls back to an empty/zero default. ``done`` and
        ``reward`` are read from the top level of *payload* and are copied
        onto both the observation and the returned ``StepResult``.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        # Built inside the method on purpose: every call gets brand-new
        # list/dict defaults, so no two observations can share one.
        fallbacks = {
            "case_id": "",
            "task_id": "",
            "difficulty": "",
            "invoice_summary": "",
            "goal": "",
            "available_actions": [],
            "revealed_documents": [],
            "findings": [],
            "remaining_steps": 0,
            "last_action_result": "",
            "last_action_error": False,
            "warnings": [],
            "grader_result": {},
            "metadata": {},
        }
        obs_fields = {
            key: raw_obs.get(key, default) for key, default in fallbacks.items()
        }

        observation = InvoiceGuardObservation(
            done=done,
            reward=reward,
            **obs_fields,
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> InvoiceGuardState:
        """
        Build an ``InvoiceGuardState`` from a state response payload.

        Every field is read from the top level of *payload*; a missing key
        falls back to the default listed below (``None`` for ``episode_id``).
        """
        # Created per call so the list defaults are never shared between states.
        fallbacks = {
            "episode_id": None,
            "step_count": 0,
            "task_id": "",
            "difficulty": "",
            "case_id": "",
            "actions_taken": [],
            "documents_revealed": [],
            "findings_collected": [],
            "is_finalized": False,
            "cumulative_reward": 0.0,
        }
        state_fields = {
            key: payload.get(key, default) for key, default in fallbacks.items()
        }
        return InvoiceGuardState(**state_fields)