import asyncio import time from typing import List from service.rerank import RerankService from search_service.base_search import BaseSearchService from utils.bio_logger import bio_logger as logger from dto.bio_document import BaseBioDocument from bio_requests.rag_request import RagRequest class RagService: def __init__(self): self.rerank_service = RerankService() # 确保所有子类都被加载 self.search_services = [ subclass() for subclass in BaseSearchService.get_subclasses() ] logger.info( f"Loaded search services: {[service.__class__.__name__ for service in self.search_services]}" ) async def multi_query(self, rag_request: RagRequest) -> List[BaseBioDocument]: start_time = time.time() batch_search = [ service.filter_search(rag_request=rag_request) for service in self.search_services ] task_result = await asyncio.gather(*batch_search, return_exceptions=True) all_results = [] for result in task_result: if isinstance(result, Exception): logger.error(f"Error in search service: {result}") continue all_results.extend(result) end_search_time = time.time() logger.info( f"Found {len(all_results)} results in total,time used:{end_search_time - start_time:.2f}s" ) if rag_request.is_rerank: logger.info("RerankService: is_rerank is True") reranked_results = await self.rerank_service.rerank( rag_request=rag_request, documents=all_results ) end_rerank_time = time.time() logger.info( f"Reranked {len(reranked_results)} results,time used:{end_rerank_time - end_search_time:.2f}s" ) else: logger.info("RerankService: is_rerank is False, skip rerank") reranked_results = all_results return reranked_results[0 : rag_request.top_k]