fix(orchestrator): P2 Round Counter Semantic Mismatch - Semantic Progress Tracking (#132)
efd0997
| """ | |
| Advanced Orchestrator using Microsoft Agent Framework. | |
| This orchestrator uses the ChatAgent pattern from Microsoft's agent-framework-core | |
| package for multi-agent coordination. It provides richer orchestration capabilities | |
| including specialized agents (Search, Hypothesis, Judge, Report) coordinated by | |
| a manager agent. | |
| Note: Previously named 'orchestrator_magentic.py' - renamed to eliminate confusion | |
| with the 'magentic' PyPI package (which is a different library). | |
| Design Patterns: | |
| - Mediator: Manager agent coordinates between specialized agents | |
| - Strategy: Different agents implement different strategies for their tasks | |
| - Observer: Event stream allows UI to observe progress | |
| """ | |

import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

import structlog
from agent_framework import (
    MAGENTIC_EVENT_TYPE_ORCHESTRATOR,
    AgentRunUpdateEvent,
    ChatAgent,
    ExecutorCompletedEvent,
    MagenticBuilder,
    WorkflowOutputEvent,
)

from src.agents.magentic_agents import (
    create_hypothesis_agent,
    create_judge_agent,
    create_report_agent,
    create_search_agent,
)
from src.agents.state import get_magentic_state, init_magentic_state
from src.clients.base import BaseChatClient
from src.clients.factory import get_chat_client
from src.config.domain import ResearchDomain, get_domain_config
from src.orchestrators.base import OrchestratorProtocol
from src.utils.config import settings
from src.utils.models import AgentEvent
from src.utils.service_loader import get_embedding_service_if_available

if TYPE_CHECKING:
    from src.services.embedding_protocol import EmbeddingServiceProtocol

logger = structlog.get_logger()

# Agent ID constants - prevents silent breakage if agents are renamed
REPORTER_AGENT_ID = "reporter"
SEARCHER_AGENT_ID = "searcher"
JUDGE_AGENT_ID = "judge"
HYPOTHESIZER_AGENT_ID = "hypothesizer"


@dataclass
class WorkflowState:
    """Tracks mutable state during workflow execution."""

    iteration: int = 0
    reporter_ran: bool = False
    current_message_buffer: str = ""
    current_agent_id: str | None = None
    last_streamed_length: int = 0
    final_event_received: bool = False


class AdvancedOrchestrator(OrchestratorProtocol):
    """
    Advanced orchestrator using Microsoft Agent Framework ChatAgent pattern.

    Each agent has an internal LLM that understands natural language
    instructions from the manager and can call tools appropriately.

    This orchestrator provides:
    - Multi-agent coordination (Search, Hypothesis, Judge, Report)
    - Manager agent for workflow orchestration
    - Streaming events for real-time UI updates
    - Configurable timeouts and round limits
    """

    def __init__(
        self,
        max_rounds: int = 5,
        chat_client: BaseChatClient | None = None,
        provider: str | None = None,
        api_key: str | None = None,
        domain: ResearchDomain | str | None = None,
        timeout_seconds: float | None = None,
    ) -> None:
        """Initialize the advanced orchestrator.

        Args:
            max_rounds: Maximum number of coordination rounds.
            chat_client: Optional pre-configured chat client.
            provider: Optional provider override ("openai", "huggingface").
            api_key: Optional API key override.
            domain: Research domain for customization.
            timeout_seconds: Optional timeout override (defaults to settings).
        """
        self._max_rounds = max_rounds
        self.domain = domain or ResearchDomain.SEXUAL_HEALTH
        self.domain_config = get_domain_config(self.domain)
        self._timeout_seconds = timeout_seconds or settings.advanced_timeout
        self.logger = logger.bind(orchestrator="advanced")
        # Use provided client or create one via factory
        self._chat_client = chat_client or get_chat_client(
            provider=provider,
            api_key=api_key,
        )
        # Store API key for service initialization
        self._api_key = api_key
        # Event stream for UI updates
        self._events: list[AgentEvent] = []
        # Initialize services lazily
        self._embedding_service: EmbeddingServiceProtocol | None = None
        # Track execution statistics
        self.stats = {
            "rounds": 0,
            "searches": 0,
            "hypotheses": 0,
            "reports": 0,
            "errors": 0,
        }

    def _init_embedding_service(self) -> "EmbeddingServiceProtocol | None":
        """Initialize embedding service if available."""
        return get_embedding_service_if_available(api_key=self._api_key)

    def _build_workflow(self) -> Any:
        """Build the workflow with ChatAgent participants."""
        # Create agents with internal LLMs
        search_agent = create_search_agent(self._chat_client, domain=self.domain)
        judge_agent = create_judge_agent(self._chat_client, domain=self.domain)
        hypothesis_agent = create_hypothesis_agent(self._chat_client, domain=self.domain)
        report_agent = create_report_agent(self._chat_client, domain=self.domain)
        # Manager chat client (orchestrates the agents)
        manager_client = self._chat_client
        manager_agent = ChatAgent(chat_client=manager_client)
        return (
            MagenticBuilder()
            .participants(
                searcher=search_agent,
                hypothesizer=hypothesis_agent,
                judge=judge_agent,
                reporter=report_agent,
            )
            .with_standard_manager(
                agent=manager_agent,
                max_round_count=self._max_rounds,
                max_stall_count=3,
                max_reset_count=2,
            )
            .build()
        )
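
    # Note: the keyword names given to .participants() above (searcher,
    # hypothesizer, judge, reporter) are what the run() loop later sees as
    # executor IDs (per the "executor_id usually maps to the agent name" note
    # there), so they should stay aligned with the *_AGENT_ID constants.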
| def _create_task_prompt(self, query: str) -> str: | |
| """Create the initial task prompt for the manager agent.""" | |
| return f"""Research {self.domain_config.report_focus} for: {query} | |
| ## CRITICAL RULE | |
| When JudgeAgent says "SUFFICIENT EVIDENCE" or "STOP SEARCHING": | |
| β IMMEDIATELY delegate to ReportAgent for synthesis | |
| β Do NOT continue searching or gathering more evidence | |
| β The Judge has determined evidence quality is adequate | |
| ## Standard Workflow | |
| 1. SearchAgent: Find evidence from PubMed, ClinicalTrials.gov, and Europe PMC | |
| 2. HypothesisAgent: Generate mechanistic hypotheses (Drug -> Target -> Pathway -> Effect) | |
| 3. JudgeAgent: Evaluate if evidence is sufficient | |
| 4. If insufficient -> SearchAgent refines search based on gaps | |
| 5. If sufficient -> ReportAgent synthesizes final report | |
| Focus on: | |
| - Identifying specific molecular targets | |
| - Understanding mechanism of action | |
| - Finding clinical evidence supporting hypotheses | |
| The final output should be a structured research report.""" | |
| def _get_agent_semantic_name(self, agent_id: str) -> str: | |
| """Map internal agent ID to user-facing semantic name.""" | |
| name = agent_id.lower() | |
| if SEARCHER_AGENT_ID in name: | |
| return "SearchAgent" | |
| if JUDGE_AGENT_ID in name: | |
| return "JudgeAgent" | |
| if HYPOTHESIZER_AGENT_ID in name: | |
| return "HypothesisAgent" | |
| if REPORTER_AGENT_ID in name: | |
| return "ReportAgent" | |
| return "ManagerAgent" | |
| async def _init_workflow_events(self, query: str) -> AsyncGenerator[AgentEvent, None]: | |
| """Yield initialization events.""" | |
| yield AgentEvent( | |
| type="started", | |
| message=f"Starting research (Advanced mode): {query}", | |
| iteration=0, | |
| ) | |
| yield AgentEvent( | |
| type="progress", | |
| message="Loading embedding service (LlamaIndex/ChromaDB)...", | |
| iteration=0, | |
| ) | |
| async def _synthesize_fallback( | |
| self, | |
| iteration: int, | |
| reason: str, | |
| ) -> AsyncGenerator[AgentEvent, None]: | |
| """ | |
| Unified fallback synthesis for all termination scenarios. | |
| This method handles synthesis when the workflow terminates without | |
| a proper report from ReportAgent. It's a safety net for: | |
| - Timeout scenarios | |
| - Manager model failing to delegate to ReportAgent (7B model limitation) | |
| - Max rounds reached without synthesis | |
| Args: | |
| iteration: Current workflow iteration count | |
| reason: Why synthesis is being forced ("timeout", "no_reporter", "max_rounds") | |
| """ | |
| status_messages = { | |
| "timeout": "Workflow timed out. Synthesizing available evidence...", | |
| "no_reporter": "Synthesizing research findings...", | |
| "max_rounds": "Max rounds reached. Synthesizing findings...", | |
| } | |
| try: | |
| state = get_magentic_state() | |
| evidence_summary = await state.memory.get_context_summary() | |
| report_agent = create_report_agent(self._chat_client, domain=self.domain) | |
| yield AgentEvent( | |
| type="synthesizing", | |
| message=status_messages.get(reason, "Synthesizing..."), | |
| iteration=iteration, | |
| ) | |
| synthesis_result = await report_agent.run( | |
| "Synthesize research report from this evidence. " | |
| f"If evidence is sparse, say so.\n\n{evidence_summary}" | |
| ) | |
| yield AgentEvent( | |
| type="complete", | |
| message=synthesis_result.text, | |
| data={"reason": f"{reason}_synthesis", "iterations": iteration}, | |
| iteration=iteration, | |
| ) | |
| except Exception as synth_error: | |
| logger.error("Fallback synthesis failed", reason=reason, error=str(synth_error)) | |
| yield AgentEvent( | |
| type="complete", | |
| message=f"Research completed. Synthesis failed: {synth_error}", | |
| data={"reason": f"{reason}_synthesis_failed", "iterations": iteration}, | |
| iteration=iteration, | |
| ) | |

    async def run(  # noqa: PLR0915 - Complex but necessary for event stream handling
        self,
        query: str,
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Run the workflow.

        Args:
            query: User's research question

        Yields:
            AgentEvent objects for real-time UI updates
        """
        logger.info("Starting Advanced orchestrator", query=query)
        async for event in self._init_workflow_events(query):
            yield event
        # Initialize context state
        embedding_service = self._init_embedding_service()
        yield AgentEvent(
            type="progress",
            message="Initializing research memory...",
            iteration=0,
        )
        init_magentic_state(query, embedding_service)
        yield AgentEvent(
            type="progress",
            message="Building agent team (Search, Judge, Hypothesis, Report)...",
            iteration=0,
        )
        workflow = self._build_workflow()
        task = self._create_task_prompt(query)
        # UX FIX: Yield thinking state before blocking workflow call
        # The workflow.run_stream() blocks for 2+ minutes on first LLM call
        yield AgentEvent(
            type="thinking",
            message=(
                f"Multi-agent reasoning in progress (Limit: {self._max_rounds} Manager rounds)... "
                "Allocating time for deep research..."
            ),
            iteration=0,
        )
        state = WorkflowState()
        try:
            async with asyncio.timeout(self._timeout_seconds):
                async for event in workflow.run_stream(task):
                    # 1. Handle Streaming (Source of Truth for Content)
                    if isinstance(event, AgentRunUpdateEvent) and event.data:
                        author = getattr(event.data, "author_name", None)
                        # Detect agent switch to clear buffer
                        if author != state.current_agent_id:
                            state.current_message_buffer = ""
                            state.current_agent_id = author
                        text = getattr(event.data, "text", None)
                        if text:
                            state.current_message_buffer += text
                            yield AgentEvent(
                                type="streaming",
                                message=text,
                                data={"agent_id": author},
                                iteration=state.iteration,
                            )
                        continue
                    # 2. Handle Completion Signal
                    if isinstance(event, ExecutorCompletedEvent):
                        state.iteration += 1
                        # P1 FIX: Track if ReportAgent produced output
                        # Note: ExecutorCompletedEvent might not have agent_id directly accessible
                        # The executor_id usually maps to the agent name
                        agent_name = getattr(event, "executor_id", "") or "unknown"
                        if REPORTER_AGENT_ID in agent_name.lower():
                            state.reporter_ran = True
                        comp_event, prog_event = self._handle_completion_event(
                            event, state.current_message_buffer, state.iteration
                        )
                        yield comp_event
                        yield prog_event
                        # P2 BUG FIX: Save length before clearing
                        state.last_streamed_length = len(state.current_message_buffer)
                        # Clear buffer after consuming
                        state.current_message_buffer = ""
                        continue
                    # 3. Handle Final Events Inline (P2 Duplicate Report Fix + P1 Forced Synthesis)
                    if isinstance(event, WorkflowOutputEvent):
                        if state.final_event_received:
                            continue  # Skip duplicate final events
                        state.final_event_received = True
                        # P1 FIX: Force synthesis if ReportAgent never ran
                        if not state.reporter_ran:
                            logger.warning(
                                "ReportAgent never ran - forcing synthesis",
                                iterations=state.iteration,
                            )
                            async for synth_event in self._synthesize_fallback(
                                state.iteration, "no_reporter"
                            ):
                                yield synth_event
                        else:
                            yield self._handle_final_event(
                                event, state.iteration, state.last_streamed_length
                            )
                        continue
                    # 4. Handle other events normally
                    agent_event = self._process_event(event, state.iteration)
                    if agent_event:
                        yield agent_event
            # GUARANTEE: Always emit termination event if stream ends without one
            # (e.g., max rounds reached)
            if not state.final_event_received:
                logger.warning(
                    "Workflow ended without final event",
                    iterations=state.iteration,
                )
                # P1 FIX: Force synthesis if ReportAgent never ran
                if not state.reporter_ran:
                    async for synth_event in self._synthesize_fallback(
                        state.iteration, "max_rounds"
                    ):
                        yield synth_event
                else:
                    yield AgentEvent(
                        type="complete",
                        message=(
                            f"Research completed after {state.iteration} agent rounds. "
                            "Max iterations reached - results may be partial. "
                            "Try a more specific query for better results."
                        ),
                        data={
                            "iterations": state.iteration,
                            "reason": "max_rounds_reached",
                        },
                        iteration=state.iteration,
                    )
        except TimeoutError:
            async for event in self._synthesize_fallback(state.iteration, "timeout"):
                yield event
        except Exception as e:
            logger.error("Workflow failed", error=str(e))
            yield AgentEvent(
                type="error",
                message=f"Workflow error: {e!s}",
                iteration=state.iteration,
            )

    def _handle_completion_event(
        self,
        event: ExecutorCompletedEvent,
        buffer: str,
        iteration: int,
    ) -> tuple[AgentEvent, AgentEvent]:
        """Handle an agent completion event using the accumulated buffer."""
        # Use buffer if available, otherwise fall back cautiously
        # (Only fall back if buffer empty, which implies tool-only turn)
        text_content = buffer
        if not text_content:
            # ExecutorCompletedEvent doesn't carry the message directly in the same way
            # Try extraction but ignore repr strings AND empty strings
            # The result is often in event.result or similar, but buffering is safer
            text_content = "Action completed (Tool Call)"
        agent_id = getattr(event, "executor_id", "unknown") or "unknown"
        event_type = self._get_event_type_for_agent(agent_id)
        semantic_name = self._get_agent_semantic_name(agent_id)
        completion_event = AgentEvent(
            type=event_type,
            message=f"{semantic_name}: {self._smart_truncate(text_content)}",
            iteration=iteration,
        )
        progress_event = AgentEvent(
            type="progress",
            message=f"Step {iteration}: {semantic_name} task completed",
            iteration=iteration,
        )
        return completion_event, progress_event

    def _handle_final_event(
        self,
        event: WorkflowOutputEvent,
        iteration: int,
        last_streamed_length: int,
    ) -> AgentEvent:
        """Handle final workflow events with duplicate content suppression (P2 Bug Fix)."""
        # DECISION: Did we stream substantial content?
        if last_streamed_length > 100:
            # YES: Final event is a SIGNAL, not a payload
            return AgentEvent(
                type="complete",
                message="Research complete.",
                data={
                    "iterations": iteration,
                    "streamed_chars": last_streamed_length,
                },
                iteration=iteration,
            )
        # NO: Final event must carry the payload (tool-only turn, cache hit)
        text = self._extract_text(event.data) if event.data else "Research complete"
        return AgentEvent(
            type="complete",
            message=text,
            data={"iterations": iteration},
            iteration=iteration,
        )
| def _extract_text(self, message: Any) -> str: | |
| """ | |
| Defensively extract text from a message object. | |
| Handles ChatMessage objects from both OpenAI and HuggingFace clients. | |
| ChatMessage has: .text (str), .contents (list of content objects) | |
| Also handles plain string messages (e.g., WorkflowOutputEvent.data). | |
| """ | |
| if not message: | |
| return "" | |
| # Priority 0: Handle plain string messages (e.g., WorkflowOutputEvent.data) | |
| if isinstance(message, str): | |
| # Filter out obvious repr-style noise | |
| if not (message.startswith("<") and "object at" in message): | |
| return message | |
| return "" | |
| # Priority 1: .text (standard ChatMessage text content) | |
| if hasattr(message, "text") and message.text: | |
| text = message.text | |
| # Verify it's actually a string, not the object itself | |
| if isinstance(text, str) and not (text.startswith("<") and "object at" in text): | |
| return text | |
| # Priority 2: .contents (list of FunctionCallContent, TextContent, etc.) | |
| # This handles tool call responses from HuggingFace | |
| if hasattr(message, "contents") and message.contents: | |
| parts = [] | |
| for content in message.contents: | |
| # TextContent has .text | |
| if hasattr(content, "text") and content.text: | |
| parts.append(str(content.text)) | |
| # FunctionCallContent has .name and .arguments | |
| elif hasattr(content, "name"): | |
| parts.append(f"[Tool: {content.name}]") | |
| if parts: | |
| return " ".join(parts) | |
| # Priority 3: .content (legacy - some frameworks use singular) | |
| if hasattr(message, "content") and message.content: | |
| content = message.content | |
| if isinstance(content, str): | |
| return content | |
| if isinstance(content, list): | |
| return " ".join([str(c.text) for c in content if hasattr(c, "text")]) | |
| # Fallback: Return empty string instead of repr | |
| # The repr is useless for display purposes | |
| return "" | |

    def _get_event_type_for_agent(
        self,
        agent_name: str,
    ) -> Literal["search_complete", "judge_complete", "hypothesizing", "synthesizing", "judging"]:
        """Map agent name to appropriate event type.

        Args:
            agent_name: The agent ID from the workflow event

        Returns:
            Event type string matching AgentEvent.type Literal
        """
        agent_lower = agent_name.lower()
        if SEARCHER_AGENT_ID in agent_lower:
            return "search_complete"
        if JUDGE_AGENT_ID in agent_lower:
            return "judge_complete"
        if HYPOTHESIZER_AGENT_ID in agent_lower:
            return "hypothesizing"
        if REPORTER_AGENT_ID in agent_lower:
            return "synthesizing"
        return "judging"  # Default for unknown agents

    def _smart_truncate(self, text: str, max_len: int = 200) -> str:
        """Truncate at sentence boundary to avoid cutting words."""
        if len(text) <= max_len:
            return text
        # Find last sentence boundary before limit
        truncated = text[:max_len]
        last_period = truncated.rfind(". ")
        if last_period > max_len // 2:
            return truncated[: last_period + 1]
        # Fallback to word boundary
        return truncated.rsplit(" ", 1)[0] + "..."
| def _process_event(self, event: Any, iteration: int) -> AgentEvent | None: | |
| """Process workflow event into AgentEvent.""" | |
| # Handle orchestrator messages (formerly MagenticOrchestratorMessageEvent) | |
| # We check the event type string directly | |
| if getattr(event, "type", "") == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: | |
| kind = getattr(event, "kind", "") | |
| message = getattr(event, "message", "") | |
| # FILTERING: Skip internal framework bookkeeping | |
| if kind in ("task_ledger", "instruction"): | |
| return None | |
| # TRANSFORMATION: Handle user_task BEFORE text extraction | |
| # (user_task uses static message, doesn't need text content) | |
| if kind == "user_task": | |
| return AgentEvent( | |
| type="progress", | |
| message="Manager assigning research task to agents...", | |
| iteration=iteration, | |
| ) | |
| # For other manager events, extract and validate text | |
| text = self._extract_text(message) | |
| if not text: | |
| return None | |
| # Default fallback for other manager events | |
| return AgentEvent( | |
| type="judging", | |
| message=f"Manager ({kind}): {self._smart_truncate(text)}", | |
| iteration=iteration, | |
| ) | |
| # NOTE: The following event types are handled inline in run() loop and never reach | |
| # this method due to `continue` statements: | |
| # - ExecutorCompletedEvent: Accumulator Pattern | |
| # - AgentRunUpdateEvent: Accumulator Pattern | |
| # - WorkflowOutputEvent: P2 Duplicate Fix via _handle_final_event() | |
| return None | |
| def _create_deprecated_alias() -> type["AdvancedOrchestrator"]: | |
| """Create a deprecated alias that warns on use.""" | |
| import warnings | |
| class MagenticOrchestrator(AdvancedOrchestrator): | |
| """Deprecated alias for AdvancedOrchestrator. | |
| .. deprecated:: 0.1.0 | |
| Use :class:`AdvancedOrchestrator` instead. | |
| """ | |
| def __init__(self, *args: Any, **kwargs: Any) -> None: | |
| """Initialize deprecated MagenticOrchestrator (use AdvancedOrchestrator).""" | |
| warnings.warn( | |
| "MagenticOrchestrator is deprecated, use AdvancedOrchestrator instead. " | |
| "The name 'magentic' was confusing with the 'magentic' PyPI package.", | |
| DeprecationWarning, | |
| stacklevel=2, | |
| ) | |
| super().__init__(*args, **kwargs) | |
| return MagenticOrchestrator | |
| # Backwards compatibility alias with deprecation warning | |
| MagenticOrchestrator = _create_deprecated_alias() | |
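

# Minimal usage sketch (illustrative only, not part of the module's public
# surface): consumes the run() event stream and prints each AgentEvent. Assumes
# provider credentials are already configured via settings; the query string is
# a placeholder.
if __name__ == "__main__":

    async def _demo() -> None:
        orchestrator = AdvancedOrchestrator(max_rounds=5)
        async for event in orchestrator.run("placeholder research question"):
            # AgentEvent carries `type`, `message`, and `iteration` (see usage above).
            print(f"[{event.type}] iter={event.iteration} {event.message}")

    asyncio.run(_demo())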