nothingworry committed
Commit 6d531e9 · 1 Parent(s): 31f3625

Multi-Tool Parallel Execution

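For orientation, this commit teaches the planner and the orchestrator a second step shape. A minimal sketch of both shapes, following the formats defined in the diffs below (the query strings are invented examples):

    # Sequential plan (existing behavior): each step runs in order.
    steps = [
        {"tool": "rag", "input": {"query": "internal leave policy"}},
        {"tool": "web", "input": {"query": "internal leave policy"}},
        {"tool": "llm", "input": {"query": "internal leave policy"}},
    ]

    # Parallel plan (new): one step carries both queries, and the
    # orchestrator runs RAG and Web concurrently before LLM synthesis.
    steps = [
        {"parallel": {"rag": "internal leave policy",
                      "web": "current statutory leave rules"}},
        {"tool": "llm", "input": {"query": "internal leave policy"}},
    ]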
backend/api/services/agent_orchestrator.py CHANGED
@@ -23,6 +23,7 @@ from .llm_client import LLMClient
 from ..mcp_clients.mcp_client import MCPClient
 from .tool_scoring import ToolScoringService
 from ..storage.analytics_store import AnalyticsStore
+from .result_merger import merge_parallel_results, format_merged_context_for_prompt
 import time
 
 
@@ -661,126 +662,322 @@ Response:"""
                                  reasoning_trace: List[Dict[str, Any]],
                                  pre_fetched_rag: Optional[Dict[str, Any]] = None) -> AgentResponse:
         """
-        Execute multiple tools in sequence and synthesize results with LLM.
+        Execute multiple tools in sequence or parallel and synthesize results with LLM.
+        Supports parallel execution when steps are marked with "parallel" flag.
         """
+        start_time = time.time()
         rag_data = None
         web_data = None
         admin_data = None
         collected_data = []
+        tools_used = []
+        total_tokens = 0
 
-        parallel_tasks = {}
-        rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
-        web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
-        if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
-            if not pre_fetched_rag:
-                parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
-            parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))
-
-        # Execute each step in sequence
+        # Check if any step has parallel execution flag
+        parallel_step = None
         for step_info in steps:
-            tool_name = step_info.get("tool")
-            step_input = step_info.get("input") or {}
-            query = step_input.get("query") or req.message
+            if step_info.get("parallel"):
+                parallel_step = step_info
+                break
 
-            try:
-                if tool_name == "rag":
-                    # Reuse pre-fetched RAG if available, otherwise fetch
-                    if pre_fetched_rag and query == rag_parallel_query:
-                        rag_resp = pre_fetched_rag
-                        tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
-                    elif parallel_tasks.get("rag") and query == rag_parallel_query:
-                        rag_resp = await parallel_tasks["rag"]
-                        tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
-                    else:
-                        rag_resp = await self.mcp.call_rag(req.tenant_id, query)
-                        tool_traces.append({"tool": "rag", "response": rag_resp})
-                    rag_data = rag_resp
-                    reasoning_trace.append({
-                        "step": "tool_execution",
-                        "tool": "rag",
-                        "hit_count": len(self._extract_hits(rag_resp)),
-                        "summary": self._summarize_hits(rag_resp, limit=2)
-                    })
-                    # Extract snippets for prompt
-                    if isinstance(rag_resp, dict):
-                        hits = rag_resp.get("results") or rag_resp.get("hits") or []
-                        for h in hits[:5]:
-                            txt = h.get("text") or h.get("content") or str(h)
-                            collected_data.append(f"[RAG] {txt}")
-
-                elif tool_name == "web":
-                    if parallel_tasks.get("web") and query == web_parallel_query:
-                        web_resp = await parallel_tasks["web"]
-                        tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
+        # Handle parallel execution if detected
+        if parallel_step and parallel_step.get("parallel"):
+            parallel_config = parallel_step.get("parallel")
+            parallel_tasks = {}
+            start_time_parallel = time.time()
+
+            # Prepare parallel tasks
+            if "rag" in parallel_config:
+                rag_query = parallel_config["rag"]
+                if pre_fetched_rag:
+                    # Use pre-fetched RAG if available - create a simple async function
+                    async def get_prefetched_rag():
+                        return pre_fetched_rag
+                    parallel_tasks["rag"] = get_prefetched_rag()
+                else:
+                    parallel_tasks["rag"] = self.mcp.call_rag(req.tenant_id, rag_query)
+
+            if "web" in parallel_config:
+                web_query = parallel_config["web"]
+                parallel_tasks["web"] = self.mcp.call_web(req.tenant_id, web_query)
+
+            # Execute tools in parallel
+            if parallel_tasks:
+                reasoning_trace.append({
+                    "step": "parallel_execution",
+                    "tools": list(parallel_tasks.keys()),
+                    "mode": "parallel"
+                })
+
+                parallel_results = await self.run_parallel_tools(parallel_tasks)
+                parallel_latency_ms = int((time.time() - start_time_parallel) * 1000)
+
+                # Process RAG results
+                if "rag" in parallel_results:
+                    rag_result = parallel_results["rag"]
+                    if isinstance(rag_result, Exception):
+                        tool_traces.append({"tool": "rag", "error": str(rag_result), "note": "parallel"})
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "rag",
+                            "status": "error",
+                            "error": str(rag_result),
+                            "latency_ms": parallel_latency_ms
+                        })
+                        self.analytics.log_tool_usage(
+                            tenant_id=req.tenant_id,
+                            tool_name="rag",
+                            latency_ms=parallel_latency_ms,
+                            success=False,
+                            error_message=str(rag_result)[:200],
+                            user_id=req.user_id
+                        )
                     else:
-                        web_resp = await self.mcp.call_web(req.tenant_id, query)
-                        tool_traces.append({"tool": "web", "response": web_resp})
-                    web_data = web_resp
-                    reasoning_trace.append({
-                        "step": "tool_execution",
-                        "tool": "web",
-                        "hit_count": len(self._extract_hits(web_resp)),
-                        "summary": self._summarize_hits(web_resp, limit=2)
-                    })
-                    # Extract snippets for prompt
-                    if isinstance(web_resp, dict):
-                        hits = web_resp.get("results") or web_resp.get("items") or []
-                        for h in hits[:5]:
-                            title = h.get("title") or h.get("headline") or ""
-                            snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
-                            url = h.get("url") or h.get("link") or ""
-                            collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
-
-                elif tool_name == "admin":
-                    admin_resp = await self.mcp.call_admin(req.tenant_id, query)
-                    tool_traces.append({"tool": "admin", "response": admin_resp})
-                    admin_data = admin_resp
-                    collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
-                    reasoning_trace.append({
-                        "step": "tool_execution",
-                        "tool": "admin",
-                        "status": "completed"
-                    })
+                        rag_data = rag_result
+                        tools_used.append("rag")
+                        tool_traces.append({"tool": "rag", "response": rag_result, "note": "parallel"})
+                        hits_count = len(self._extract_hits(rag_result))
+                        avg_score = None
+                        top_score = None
+                        if hits_count > 0:
+                            scores = [h.get("score", 0.0) for h in self._extract_hits(rag_result) if isinstance(h, dict) and "score" in h]
+                            if scores:
+                                avg_score = sum(scores) / len(scores)
+                                top_score = max(scores)
+                        self.analytics.log_rag_search(
+                            tenant_id=req.tenant_id,
+                            query=req.message[:500],
+                            hits_count=hits_count,
+                            avg_score=avg_score,
+                            top_score=top_score,
+                            latency_ms=parallel_latency_ms
+                        )
+                        self.analytics.log_tool_usage(
+                            tenant_id=req.tenant_id,
+                            tool_name="rag",
+                            latency_ms=parallel_latency_ms,
+                            success=True,
+                            user_id=req.user_id
+                        )
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "rag",
+                            "hit_count": hits_count,
+                            "summary": self._summarize_hits(rag_result, limit=2),
+                            "latency_ms": parallel_latency_ms,
+                            "mode": "parallel"
+                        })
 
-                elif tool_name == "llm":
-                    # LLM is always last - synthesize all collected data
-                    break
+                # Process Web results
+                if "web" in parallel_results:
+                    web_result = parallel_results["web"]
+                    if isinstance(web_result, Exception):
+                        tool_traces.append({"tool": "web", "error": str(web_result), "note": "parallel"})
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "web",
+                            "status": "error",
+                            "error": str(web_result),
+                            "latency_ms": parallel_latency_ms
+                        })
+                        self.analytics.log_tool_usage(
+                            tenant_id=req.tenant_id,
+                            tool_name="web",
+                            latency_ms=parallel_latency_ms,
+                            success=False,
+                            error_message=str(web_result)[:200],
+                            user_id=req.user_id
+                        )
+                    else:
+                        web_data = web_result
+                        tools_used.append("web")
+                        tool_traces.append({"tool": "web", "response": web_result, "note": "parallel"})
+                        hits_count = len(self._extract_hits(web_result))
+                        self.analytics.log_tool_usage(
+                            tenant_id=req.tenant_id,
+                            tool_name="web",
+                            latency_ms=parallel_latency_ms,
+                            success=True,
+                            user_id=req.user_id
+                        )
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "web",
+                            "hit_count": hits_count,
+                            "summary": self._summarize_hits(web_result, limit=2),
+                            "latency_ms": parallel_latency_ms,
+                            "mode": "parallel"
+                        })
 
-            except Exception as e:
-                tool_traces.append({"tool": tool_name, "error": str(e)})
-                # Continue with other tools even if one fails
+                # Merge parallel results
+                merged_context = merge_parallel_results(parallel_results)
+                sources_list = list(set(e.get("source") for e in merged_context if e.get("source"))) if merged_context else []
                 reasoning_trace.append({
-                    "step": "error",
-                    "tool": tool_name,
-                    "error": str(e)
+                    "step": "result_merger",
+                    "merged_items": len(merged_context),
+                    "sources": sources_list
                 })
 
-        # Build comprehensive prompt with all collected data
-        data_section = "\n---\n".join(collected_data) if collected_data else ""
-
+                # Format merged context for prompt
+                data_section = format_merged_context_for_prompt(merged_context, max_items=10)
+            else:
+                data_section = ""
+
+        else:
+            # Sequential execution (original logic)
+            parallel_tasks = {}
+            rag_parallel_query = self._first_query_for_tool(steps, "rag", req.message)
+            web_parallel_query = self._first_query_for_tool(steps, "web", req.message)
+            if rag_parallel_query and web_parallel_query and rag_parallel_query == web_parallel_query:
+                if not pre_fetched_rag:
+                    parallel_tasks["rag"] = asyncio.create_task(self.mcp.call_rag(req.tenant_id, rag_parallel_query))
+                parallel_tasks["web"] = asyncio.create_task(self.mcp.call_web(req.tenant_id, web_parallel_query))
+
+            # Execute each step in sequence
+            for step_info in steps:
+                tool_name = step_info.get("tool")
+                step_input = step_info.get("input") or {}
+                query = step_input.get("query") or req.message
+
+                try:
+                    if tool_name == "rag":
+                        # Reuse pre-fetched RAG if available, otherwise fetch
+                        if pre_fetched_rag and query == rag_parallel_query:
+                            rag_resp = pre_fetched_rag
+                            tool_traces.append({"tool": "rag", "response": rag_resp, "note": "used_pre_fetched"})
+                        elif parallel_tasks.get("rag") and query == rag_parallel_query:
+                            rag_resp = await parallel_tasks["rag"]
+                            tool_traces.append({"tool": "rag", "response": rag_resp, "note": "parallel"})
+                        else:
+                            rag_resp = await self.mcp.call_rag(req.tenant_id, query)
+                            tool_traces.append({"tool": "rag", "response": rag_resp})
+                        rag_data = rag_resp
+                        tools_used.append("rag")
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "rag",
+                            "hit_count": len(self._extract_hits(rag_resp)),
+                            "summary": self._summarize_hits(rag_resp, limit=2)
+                        })
+                        # Extract snippets for prompt
+                        if isinstance(rag_resp, dict):
+                            hits = rag_resp.get("results") or rag_resp.get("hits") or []
+                            for h in hits[:5]:
+                                txt = h.get("text") or h.get("content") or str(h)
+                                collected_data.append(f"[RAG] {txt}")
+
+                    elif tool_name == "web":
+                        if parallel_tasks.get("web") and query == web_parallel_query:
+                            web_resp = await parallel_tasks["web"]
+                            tool_traces.append({"tool": "web", "response": web_resp, "note": "parallel"})
+                        else:
+                            web_resp = await self.mcp.call_web(req.tenant_id, query)
+                            tool_traces.append({"tool": "web", "response": web_resp})
+                        web_data = web_resp
+                        tools_used.append("web")
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "web",
+                            "hit_count": len(self._extract_hits(web_resp)),
+                            "summary": self._summarize_hits(web_resp, limit=2)
+                        })
+                        # Extract snippets for prompt
+                        if isinstance(web_resp, dict):
+                            hits = web_resp.get("results") or web_resp.get("items") or []
+                            for h in hits[:5]:
+                                title = h.get("title") or h.get("headline") or ""
+                                snippet = h.get("snippet") or h.get("summary") or h.get("text") or ""
+                                url = h.get("url") or h.get("link") or ""
+                                collected_data.append(f"[WEB] {title}\n{snippet}\nSource: {url}")
+
+                    elif tool_name == "admin":
+                        admin_resp = await self.mcp.call_admin(req.tenant_id, query)
+                        tool_traces.append({"tool": "admin", "response": admin_resp})
+                        admin_data = admin_resp
+                        tools_used.append("admin")
+                        collected_data.append(f"[ADMIN] {json.dumps(admin_resp)}")
+                        reasoning_trace.append({
+                            "step": "tool_execution",
+                            "tool": "admin",
+                            "status": "completed"
+                        })
+
+                    elif tool_name == "llm":
+                        # LLM is always last - synthesize all collected data
+                        break
+
+                except Exception as e:
+                    tool_traces.append({"tool": tool_name, "error": str(e)})
+                    # Continue with other tools even if one fails
+                    reasoning_trace.append({
+                        "step": "error",
+                        "tool": tool_name,
+                        "error": str(e)
+                    })
+
+            # Build comprehensive prompt with all collected data
+            data_section = "\n---\n".join(collected_data) if collected_data else ""
+
+        # Build final prompt
         if data_section:
             prompt = (
                 f"You are an assistant helping tenant {req.tenant_id}.\n\n"
-                f"The following information has been gathered from multiple sources:\n\n"
+                f"## Information Collected\n"
+                f"The following details have been gathered from multiple reliable sources:\n"
                 f"{data_section}\n\n"
-                f"User question: {req.message}\n\n"
-                f"Provide a comprehensive, accurate answer using the information above. "
-                f"Cite sources where appropriate (RAG for internal docs, WEB for online sources)."
+                f"## User Request\n"
+                f"{req.message}\n\n"
+                f"## Your Task\n"
+                f"Use the information above to directly address the user's request. "
+                f"Focus on giving the user exactly what they need—clear guidance, accurate facts, "
+                f"and practical steps whenever possible. If the information is incomplete, explain "
+                f"what can and cannot be concluded from the available data."
             )
+
         else:
             # No data collected, just answer the question
             prompt = req.message
 
         # Final LLM synthesis
         try:
+            llm_start = time.time()
             llm_out = await self.llm.simple_call(prompt, temperature=req.temperature)
+            llm_latency_ms = int((time.time() - llm_start) * 1000)
+            tools_used.append("llm")
+
+            estimated_tokens = len(llm_out) // 4 + len(prompt) // 4
+            total_tokens += estimated_tokens
+
+            self.analytics.log_tool_usage(
+                tenant_id=req.tenant_id,
+                tool_name="llm",
+                latency_ms=llm_latency_ms,
+                tokens_used=estimated_tokens,
+                success=True,
+                user_id=req.user_id
+            )
+
+            total_latency_ms = int((time.time() - start_time) * 1000)
+            self.analytics.log_agent_query(
+                tenant_id=req.tenant_id,
+                message_preview=req.message[:200],
+                intent="multi_step",
+                tools_used=tools_used,
+                total_tokens=total_tokens,
+                total_latency_ms=total_latency_ms,
+                success=True,
+                user_id=req.user_id
+            )
+
             return AgentResponse(
                 text=llm_out,
                 decision=decision,
                 tool_traces=tool_traces,
                 reasoning_trace=reasoning_trace + [{
                     "step": "llm_response",
-                    "mode": "multi_step"
+                    "mode": "multi_step_parallel" if parallel_step else "multi_step",
+                    "latency_ms": llm_latency_ms,
+                    "estimated_tokens": estimated_tokens
                 }]
             )
         except Exception as e:
@@ -826,10 +1023,21 @@ Response:"""
             snippets.append(f"{title}\n{snippet}\nSource: {url}")
 
         snippet_text = "\n---\n".join(snippets) or ""
+        # prompt = (
+        #     f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
+        #     f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
+        # )
         prompt = (
-            f"You are an assistant with access to recent web search results. Use the following results to answer.\n{snippet_text}\n\n"
-            f"User question: {req.message}\nAnswer succinctly and indicate which results you used."
+            f"You are an assistant with access to recent web search results.\n\n"
+            f"## Search Results\n"
+            f"{snippet_text}\n\n"
+            f"## User Question\n"
+            f"{req.message}\n\n"
+            f"## Your Task\n"
+            f"Provide a clear, accurate, and succinct answer based on the search results above. "
+            f"Indicate which results you used in your reasoning."
        )
+
         return prompt
 
     @staticmethod
@@ -849,6 +1057,30 @@ Response:"""
                 summaries.append(snippet[:160])
         return summaries
 
+    async def run_parallel_tools(self, tasks: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Run multiple tools in parallel using asyncio.gather.
+
+        Args:
+            tasks: Dictionary mapping tool names to coroutines, e.g.:
+                   {"rag": rag_coro, "web": web_coro}
+
+        Returns:
+            Dictionary mapping tool names to results, e.g.:
+            {"rag": rag_result, "web": web_result}
+            Exceptions are returned as values if a tool fails.
+        """
+        if not tasks:
+            return {}
+
+        names = list(tasks.keys())
+        coros = [tasks[name] for name in names]
+
+        # Run all coroutines in parallel, return exceptions instead of raising
+        results = await asyncio.gather(*coros, return_exceptions=True)
+
+        return {names[i]: results[i] for i in range(len(names))}
+
     @staticmethod
     def _first_query_for_tool(steps: List[Dict[str, Any]], tool_name: str, default_query: str) -> Optional[str]:
         for step in steps:
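The new run_parallel_tools helper is a thin wrapper over asyncio.gather(..., return_exceptions=True): a failing tool surfaces as an Exception value in the result dict instead of cancelling its sibling, which is what lets the error branches above log one tool's failure while still using the other's results. A self-contained sketch of the same pattern (fake_rag and failing_web are hypothetical stand-ins for self.mcp.call_rag and self.mcp.call_web):

    import asyncio

    async def run_parallel_tools(tasks):
        # Same contract as the method above: tool names in, results or Exceptions out.
        names = list(tasks.keys())
        results = await asyncio.gather(*(tasks[n] for n in names), return_exceptions=True)
        return dict(zip(names, results))

    async def fake_rag(query):
        return {"hits": [{"text": f"doc about {query}", "score": 0.9}]}

    async def failing_web(query):
        raise TimeoutError("search backend unreachable")  # simulated failure

    async def main():
        out = await run_parallel_tools({
            "rag": fake_rag("leave policy"),
            "web": failing_web("leave policy"),
        })
        for name, res in out.items():
            if isinstance(res, Exception):
                print(name, "failed:", res)  # web failed: search backend unreachable
            else:
                print(name, "ok:", res)      # rag ok: {'hits': [...]}

    asyncio.run(main())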
backend/api/services/result_merger.py ADDED
@@ -0,0 +1,136 @@
+"""
+Result Merger Utility
+
+Merges and ranks results from parallel tool execution (RAG + Web).
+"""
+
+from typing import List, Dict, Any, Optional
+
+
+def merge_parallel_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
+    """
+    Merge results from parallel tool execution (RAG + Web).
+
+    Args:
+        results: Dictionary with keys like "rag" and "web" containing tool outputs
+
+    Returns:
+        List of merged context entries, sorted by score (descending)
+    """
+    final_context = []
+
+    # Extract RAG results
+    if "rag" in results and results["rag"]:
+        rag_data = results["rag"]
+
+        # Handle different RAG response formats
+        if isinstance(rag_data, dict):
+            hits = rag_data.get("results") or rag_data.get("hits") or []
+        elif isinstance(rag_data, list):
+            hits = rag_data
+        else:
+            hits = []
+
+        for hit in hits:
+            if isinstance(hit, dict):
+                content = hit.get("text") or hit.get("content") or str(hit)
+                score = hit.get("score", 0.0)
+                doc_id = hit.get("doc_id") or hit.get("id")
+                source = hit.get("source") or hit.get("url") or "internal_doc"
+            else:
+                content = str(hit)
+                score = 0.5  # Default score for non-dict hits
+                doc_id = None
+                source = "internal_doc"
+
+            if content:
+                final_context.append({
+                    "source": "internal_policy",
+                    "text": content,
+                    "score": float(score),
+                    "doc_id": doc_id,
+                    "source_url": source if isinstance(source, str) else None
+                })
+
+    # Extract Web results
+    if "web" in results and results["web"]:
+        web_data = results["web"]
+
+        # Handle different Web response formats
+        if isinstance(web_data, dict):
+            items = web_data.get("results") or web_data.get("items") or []
+        elif isinstance(web_data, list):
+            items = web_data
+        else:
+            items = []
+
+        for item in items:
+            if isinstance(item, dict):
+                title = item.get("title") or item.get("headline") or ""
+                snippet = item.get("snippet") or item.get("summary") or item.get("text") or ""
+                url = item.get("url") or item.get("link") or ""
+                # Web results get a baseline confidence score
+                score = item.get("score", 0.5)
+            else:
+                title = ""
+                snippet = str(item)
+                url = ""
+                score = 0.5
+
+            if snippet or title:
+                # Combine title and snippet for better context
+                text = f"{title}\n{snippet}" if title else snippet
+                final_context.append({
+                    "source": "live_web",
+                    "text": text,
+                    "score": float(score),
+                    "url": url,
+                    "title": title
+                })
+
+    # Sort by score descending (highest relevance first)
+    final_context.sort(key=lambda x: x["score"], reverse=True)
+
+    return final_context
+
+
+def format_merged_context_for_prompt(merged_context: List[Dict[str, Any]],
+                                     max_items: int = 10) -> str:
+    """
+    Format merged context into a readable prompt section.
+
+    Args:
+        merged_context: List of merged context entries from merge_parallel_results
+        max_items: Maximum number of items to include
+
+    Returns:
+        Formatted string ready for LLM prompt
+    """
+    if not merged_context:
+        return ""
+
+    sections = []
+    for entry in merged_context[:max_items]:
+        source_label = entry.get("source", "unknown")
+        text = entry.get("text", "")
+        score = entry.get("score", 0.0)
+
+        # Format based on source type
+        if source_label == "internal_policy":
+            source_url = entry.get("source_url")
+            if source_url:
+                sections.append(f"[INTERNAL DOCUMENT - {source_url}]\n{text}")
+            else:
+                sections.append(f"[INTERNAL DOCUMENT]\n{text}")
+        elif source_label == "live_web":
+            url = entry.get("url", "")
+            title = entry.get("title", "")
+            if url:
+                sections.append(f"[WEB SOURCE - {url}]\n{title}\n{text}")
+            else:
+                sections.append(f"[WEB SOURCE]\n{title}\n{text}")
+        else:
+            sections.append(f"[{source_label.upper()}]\n{text}")
+
+    return "\n\n---\n\n".join(sections)
+
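A quick worked example of the two helpers with toy payloads (the payload shapes mirror what the orchestrator passes in, all values are invented, and the module is assumed importable as result_merger):

    from result_merger import merge_parallel_results, format_merged_context_for_prompt

    results = {
        "rag": {"hits": [{"text": "PTO accrues at 1.5 days per month.",
                          "score": 0.92, "doc_id": "hr-7"}]},
        "web": {"items": [{"title": "Statutory leave 2024",
                           "snippet": "Minimum of 20 paid days.",
                           "url": "https://example.com/leave"}]},
    }

    merged = merge_parallel_results(results)
    print([(e["source"], e["score"]) for e in merged])
    # [('internal_policy', 0.92), ('live_web', 0.5)]  (sorted by score; web defaults to 0.5)

    print(format_merged_context_for_prompt(merged, max_items=10))
    # [INTERNAL DOCUMENT - internal_doc]   <- the hit has no source/url, so the
    # PTO accrues at 1.5 days per month.      "internal_doc" fallback becomes the label
    #
    # ---
    #
    # [WEB SOURCE - https://example.com/leave]
    # Statutory leave 2024                 <- title appears once from the entry and
    # Statutory leave 2024                    once inside the merged text field
    # Minimum of 20 paid days.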
backend/api/services/tool_selector.py CHANGED
@@ -80,10 +80,13 @@ class ToolSelector:
         # ---------------------------------
         # 6. Use LLM to enhance plan if we have partial steps or complex query
         # ---------------------------------
+        # Check if we should use parallel execution (both RAG and Web needed)
+        should_parallel = needs_rag and needs_web and (needs_multiple or rag_score >= 0.55 and web_score >= 0.55)
+
         if self.llm_client and (needs_multiple or (needs_rag and needs_web) or len(steps) == 0):
             plan_prompt = f"""
 You are an enterprise MCP agent.
-You can select MULTIPLE tools in sequence to provide comprehensive answers.
+You can select MULTIPLE tools in sequence OR in parallel to provide comprehensive answers.
 
 TOOLS:
 - rag → private knowledge retrieval (use for internal/company docs)
@@ -101,8 +104,22 @@ Determine which tools are needed. You can select:
 - Web + LLM (public fact questions)
 - RAG + Web + LLM (comprehensive questions needing both sources)
 
-Return a JSON list describing the steps, e.g.:
+IMPORTANT: If the query needs BOTH internal docs (RAG) AND current/live info (Web),
+you can mark them for parallel execution by using a "parallel" step.
 
+Return a JSON list describing the steps. For parallel execution, use:
+[
+  {{
+    "parallel": {{
+      "rag": "query for internal docs",
+      "web": "query for live info"
+    }},
+    "reason": "Need both internal and live information simultaneously"
+  }},
+  {{"tool": "llm", "reason": "Synthesize all information"}}
+]
+
+For sequential execution, use:
 [
   {{"tool": "rag", "reason": "Need internal documentation"}},
   {{"tool": "web", "reason": "Need current public information"}},
@@ -125,27 +142,96 @@ Only return the JSON array. Do not include markdown formatting.
 
                 steps_json = json.loads(out)
 
-                # Replace steps with LLM-planned steps (excluding LLM, we'll add it at end)
-                steps = [
-                    step(s["tool"], {"query": text})
-                    for s in steps_json if s.get("tool") != "llm"
-                ]
+                # Check if LLM returned a parallel step
+                has_parallel = any("parallel" in s for s in steps_json)
+
+                if has_parallel:
+                    # Extract parallel step and convert to our format
+                    parallel_step = None
+                    other_steps = []
+                    for s in steps_json:
+                        if "parallel" in s:
+                            parallel_step = {"parallel": s["parallel"]}
+                        elif s.get("tool") != "llm":
+                            other_steps.append(step(s["tool"], {"query": text}))
+
+                    if parallel_step:
+                        steps = [parallel_step] + other_steps
+                    else:
+                        # Fallback: convert to regular steps
+                        steps = [
+                            step(s["tool"], {"query": text})
+                            for s in steps_json if s.get("tool") != "llm"
+                        ]
+                else:
+                    # Replace steps with LLM-planned steps (excluding LLM, we'll add it at end)
+                    steps = [
+                        step(s["tool"], {"query": text})
+                        for s in steps_json if s.get("tool") != "llm"
+                    ]
             except Exception as e:
-                # If LLM planning fails, keep existing steps or use fallback
-                if not steps:
+                # If LLM planning fails, check if we should create parallel step manually
+                if should_parallel and needs_rag and needs_web:
+                    # Create parallel step manually
+                    steps = [{
+                        "parallel": {
+                            "rag": text,
+                            "web": text
+                        }
+                    }]
+                elif not steps:
                     steps = []
 
         # ---------------------------------
-        # 7. Always end with LLM synthesis
+        # 7. If we have both RAG and Web but no parallel step, consider creating one
+        # ---------------------------------
+        if should_parallel and needs_rag and needs_web:
+            # Check if we already have a parallel step
+            has_parallel = any("parallel" in s for s in steps)
+            if not has_parallel:
+                # Replace sequential RAG and Web steps with a parallel step
+                new_steps = []
+                rag_query = text
+                web_query = text
+
+                # Extract queries from existing steps if available
+                for s in steps:
+                    if s.get("tool") == "rag":
+                        rag_query = s.get("input", {}).get("query", text)
+                    elif s.get("tool") == "web":
+                        web_query = s.get("input", {}).get("query", text)
+
+                # Create parallel step
+                new_steps.append({
+                    "parallel": {
+                        "rag": rag_query,
+                        "web": web_query
+                    }
+                })
+
+                # Keep other non-RAG/Web steps
+                for s in steps:
+                    if s.get("tool") not in ["rag", "web"]:
+                        new_steps.append(s)
+
+                steps = new_steps
+
+        # ---------------------------------
+        # 8. Always end with LLM synthesis
         # ---------------------------------
-        if not steps or steps[-1]["tool"] != "llm":
+        if not steps or (isinstance(steps[-1], dict) and steps[-1].get("tool") != "llm" and "parallel" not in steps[-1]):
             steps.append(step("llm", {
                 "rag_data": rag_results if rag_has_data else None,
                 "query": text
             }))
 
         # Build reason string showing the tool sequence
-        tool_names = [s["tool"] for s in steps]
+        tool_names = []
+        for s in steps:
+            if "parallel" in s:
+                tool_names.append("parallel(RAG+Web)")
+            elif isinstance(s, dict) and "tool" in s:
+                tool_names.append(s["tool"])
         reason = f"multi-tool plan: {' → '.join(tool_names)} | scores={tool_scores}"
 
         return _multi_step(steps, reason)
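The step-7 rewrite is easiest to see end to end: when scoring says both tools are needed but the plan came back sequential, the rag and web steps collapse into a single parallel step. A standalone sketch of that normalization (step() is assumed to build {"tool": ..., "input": {...}} as elsewhere in this file):

    def step(tool, payload):
        # assumed shape of the module's step() helper
        return {"tool": tool, "input": payload}

    text = "VPN policy and current OpenVPN CVEs"
    steps = [step("rag", {"query": text}), step("web", {"query": text})]

    # Step 7: fold sequential rag+web into one parallel step, keep the rest.
    if not any("parallel" in s for s in steps):
        rag_q = next((s["input"]["query"] for s in steps if s.get("tool") == "rag"), text)
        web_q = next((s["input"]["query"] for s in steps if s.get("tool") == "web"), text)
        others = [s for s in steps if s.get("tool") not in ("rag", "web")]
        steps = [{"parallel": {"rag": rag_q, "web": web_q}}] + others

    print(steps)
    # [{'parallel': {'rag': 'VPN policy and current OpenVPN CVEs',
    #                'web': 'VPN policy and current OpenVPN CVEs'}}]

Note that with the committed step-8 condition, a trailing parallel step suppresses the extra llm step; synthesis still happens because the orchestrator's final LLM call is unconditional.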
data/analytics.db CHANGED
Binary files a/data/analytics.db and b/data/analytics.db differ