Spaces:
Sleeping
Sleeping
| """ | |
| Pydantic models for the Invoice Processing Pipeline environment. | |
| Action: Agent submits extracted/cleaned/reconciled invoice data as JSON. | |
| Observation: Agent receives raw invoice text, feedback, and task context. | |
| State: Tracks episode progress, attempts, and scores. | |
| """ | |
| from typing import Any, Dict, List, Optional | |
| from pydantic import BaseModel, Field | |
| # --------------------------------------------------------------------------- | |
| # Action | |
| # --------------------------------------------------------------------------- | |
class InvoiceAction(BaseModel):
    """Action the agent submits each step."""

    # Required payload. The expected shape differs per task; the per-task
    # layouts are enumerated in the description text itself.
    extracted_data: Dict[str, Any] = Field(
        ...,
        description=(
            "JSON object with extracted/cleaned invoice fields. Structure depends on the task. "
            "Easy: {vendor, date, currency, total, line_items: [{description, qty, unit_price, amount}]}. "
            "Medium: {invoices: [{vendor, date, currency, total, line_items}]} (batch of cleaned invoices). "
            "Hard: {invoices: [...], discrepancies: [{invoice_idx, type, detail, expected, actual}]}. "
            "Adversarial: same schema as easy — {vendor, date, currency, total, line_items}. "
            "Negotiate: either {'question': str} to ask a clarification, or the full extraction (same schema as easy). "
            "Supply_chain: {'anomalies': [{'delivery_id', 'anomaly_type', 'detail'}]}."
        ),
    )

    # Free-text rationale; falls back to an empty string when not supplied.
    explanation: str = Field(
        "",
        description="Optional reasoning about extraction or cleaning decisions.",
    )
| # --------------------------------------------------------------------------- | |
| # Observation | |
| # --------------------------------------------------------------------------- | |
class InvoiceObservation(BaseModel):
    """What the agent sees each turn."""

    # --- Required task context -------------------------------------------
    raw_text: str = Field(
        ...,
        description="Raw invoice text (OCR-style or CSV-style)",
    )
    task_id: str = Field(
        ...,
        description=(
            "easy | medium | hard | expert | adversarial | "
            "negotiate | supply_chain | long_horizon | personalized | curriculum"
        ),
    )
    difficulty: str = Field(..., description="Same as task_id")
    task_description: str = Field(..., description="What the agent should do")

    # --- Attempt bookkeeping and grader output ---------------------------
    attempt_number: int = Field(0, description="Current attempt (0 = just reset)")
    max_attempts: int = Field(5, description="Max allowed attempts")
    feedback: str = Field("", description="Detailed grader feedback from last attempt")
    hint: str = Field("", description="Hint shown after 2+ failed attempts")

    # --- Extras whose descriptions tie them to particular tasks ----------
    reference_data: str = Field(
        "",
        description="For hard task: purchase order data to reconcile against",
    )
    reward_breakdown: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "Per-field score breakdown for easy, adversarial, and negotiate tasks. "
            "Example: {'vendor': {'score': 0.15, 'max': 0.15, 'status': 'correct'}, "
            "'date': {'score': 0.0, 'max': 0.10, 'status': 'wrong'}, ...}"
        ),
    )
    # A fresh list is created per instance via the factory.
    conversation_history: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="For negotiate task: list of {'role': 'agent'|'env', 'content': str} turns.",
    )
    phase: Optional[int] = Field(
        None,
        description="For long_horizon task: current phase (1=extract, 2=reconcile, 3=audit, 4=forecast).",
    )
    phase_context: Optional[str] = Field(
        None,
        description="For long_horizon task: accumulated findings from prior phases passed to next phase.",
    )
    agent_profile: Optional[Dict[str, Any]] = Field(
        None,
        description="For personalized task: agent's historical performance profile used to adapt difficulty.",
    )
| # --------------------------------------------------------------------------- | |
| # State | |
| # --------------------------------------------------------------------------- | |
class InvoiceState(BaseModel):
    """Internal episode state."""

    # Every field has a default, so the model can be built with no arguments.
    episode_id: str = ""
    task_id: str = "easy"
    step_count: int = 0
    done: bool = False

    # Reward tracking: latest value, best value, and the list of rewards.
    last_reward: float = 0.0
    best_reward: float = 0.0
    rewards: List[float] = Field(default_factory=list)

    # Dialogue turns and a counter of clarifications
    # (presumably for the negotiate task; confirm against the environment code).
    conversation_history: List[Dict[str, Any]] = Field(default_factory=list)
    clarification_count: int = 0

    # Long-horizon: current phase plus the context accumulated so far.
    phase: int = 1
    phase_scores: List[float] = Field(default_factory=list)
    phase_context: str = ""

    # Personalized: the agent's weak areas tracked across steps.
    agent_profile: Dict[str, Any] = Field(default_factory=dict)
| # --------------------------------------------------------------------------- | |
| # Supply Chain (documentation model) | |
| # --------------------------------------------------------------------------- | |
class SupplyChainAnomalyItem(BaseModel):
    """One entry of the 'anomalies' list in a supply_chain submission."""

    # All three fields are required strings.
    delivery_id: str
    # The known anomaly kinds are carried in the field description so they
    # appear in the generated schema rather than only in a source comment.
    # The type stays a plain str, so validation is unchanged.
    anomaly_type: str = Field(
        ...,
        description=(
            "quantity_shortfall | price_spike | "
            "unauthorized_substitution | phantom_delivery"
        ),
    )
    detail: str
# Wrapper holding a required list of SupplyChainAnomalyItem entries under the
# single key 'anomalies'.
class SupplyChainAction(BaseModel):
    anomalies: List[SupplyChainAnomalyItem]