import asyncio
import re
import time
from typing import Dict, List

from dto.bio_document import BaseBioDocument, create_bio_document
from search_service.base_search import BaseSearchService
from bio_requests.rag_request import RagRequest
from utils.bio_logger import bio_logger as logger
from service.query_rewrite import QueryRewriteService
from service.pubmed_api import PubMedApi
from service.pubmed_async_api import PubMedAsyncApi
from config.global_storage import get_model_config


class PubMedSearchService(BaseSearchService):
    def __init__(self):
        self.query_rewrite_service = QueryRewriteService()
        self.model_config = get_model_config()
        self.pubmed_topk = self.model_config["recall"]["pubmed_topk"]
        self.es_topk = self.model_config["recall"]["es_topk"]
        self.data_source = "pubmed"
    async def get_query_list(self, rag_request: RagRequest) -> List[Dict]:
        """Build the list of queries to run from a RagRequest."""
        if rag_request.is_rewrite:
            query_list = await self.query_rewrite_service.query_split(rag_request.query)
            logger.info(f"length of query_list after query_split: {len(query_list)}")
            if len(query_list) == 0:
                logger.info("query_list is empty, use query_split_for_simple")
                query_list = await self.query_rewrite_service.query_split_for_simple(
                    rag_request.query
                )
                logger.info(
                    f"length of query_list after query_split_for_simple: {len(query_list)}"
                )
            self.pubmed_topk = rag_request.pubmed_topk
            self.es_topk = rag_request.pubmed_topk
        else:
            self.pubmed_topk = rag_request.top_k
            self.es_topk = rag_request.top_k
            query_list = [
                {
                    "query_item": rag_request.query,
                    "search_type": rag_request.search_type,
                }
            ]
        return query_list
    async def search(self, rag_request: RagRequest) -> List[BaseBioDocument]:
        """Asynchronously search the PubMed database."""
        if not rag_request.query:
            return []
        start_time = time.time()
        query_list = await self.get_query_list(rag_request)

        # Use async concurrency instead of a thread pool
        articles_id_list = []
        es_articles = []
        try:
            # Build one search task per query, using PubMedApi.search_database
            async_tasks = []
            for query in query_list:
                task = self._search_pubmed_with_sync_api(
                    query["query_item"], self.pubmed_topk, query["search_type"]
                )
                async_tasks.append((query, task))

            # Run all search tasks concurrently
            results = await asyncio.gather(
                *[task for _, task in async_tasks], return_exceptions=True
            )

            # Collect article IDs, skipping queries that failed
            for i, (query, _) in enumerate(async_tasks):
                result = results[i]
                if isinstance(result, Exception):
                    logger.error(f"Error in search pubmed: {result}")
                else:
                    articles_id_list.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent PubMed search: {e}")

        # Fetch the article details
        pubmed_docs = await self.fetch_article_details(articles_id_list)

        # Merge the results
        all_results = []
        all_results.extend(pubmed_docs)
        all_results.extend(es_articles)
        logger.info(
            f"""Finished searching PubMed, query:{rag_request.query},
            total articles: {len(articles_id_list)}, total time: {time.time() - start_time:.2f}s"""
        )
        return all_results
    async def _search_pubmed_with_sync_api(
        self, query: str, top_k: int, search_type: str
    ) -> List[str]:
        """
        Run the synchronous PubMedApi.search_database method, wrapped in an
        executor so that multiple queries can be searched concurrently.

        Args:
            query: search query
            top_k: number of results to return
            search_type: search type

        Returns:
            List of article IDs
        """
        try:
            # Run the synchronous search_database method in a thread pool
            loop = asyncio.get_running_loop()
            pubmed_api = PubMedApi()
            # Use run_in_executor to execute the synchronous method asynchronously
            id_list = await loop.run_in_executor(
                None,  # use the default thread pool
                pubmed_api.search_database,
                query,
                top_k,
                search_type,
            )
            return id_list
        except Exception as e:
            logger.error(f"Error in PubMed search for query '{query}': {e}")
            raise
    async def fetch_article_details(
        self, articles_id_list: List[str]
    ) -> List[BaseBioDocument]:
        """Fetch article details from PubMed for the given article IDs."""
        if not articles_id_list:
            return []

        # Deduplicate the article IDs
        articles_id_list = list(set(articles_id_list))

        # Split the IDs into batches of group_size for concurrent fetching
        group_size = 80
        articles_id_groups = [
            articles_id_list[i : i + group_size]
            for i in range(0, len(articles_id_list), group_size)
        ]
        try:
            # Fetch the details of all batches concurrently
            batch_tasks = []
            for ids in articles_id_groups:
                pubmed_async_api = PubMedAsyncApi()
                task = pubmed_async_api.fetch_details(id_list=ids)
                batch_tasks.append(task)
            task_results = await asyncio.gather(*batch_tasks, return_exceptions=True)

            fetch_results = []
            for result in task_results:
                if isinstance(result, Exception):
                    logger.error(f"Error in fetch_details: {result}")
                    continue
                fetch_results.extend(result)
        except Exception as e:
            logger.error(f"Error in concurrent fetch_details: {e}")
            return []

        # Convert the raw records into BioDocument objects
        all_results = [
            create_bio_document(
                title=result["title"],
                abstract=result["abstract"],
                authors=self.process_authors(result["authors"]),
                doi=result["doi"],
                source=self.data_source,
                source_id=result["pmid"],
                pub_date=result["pub_date"],
                journal=result["journal"],
                text=result["abstract"],
                url=f'https://pubmed.ncbi.nlm.nih.gov/{result["pmid"]}',
            )
            for result in fetch_results
        ]
        return all_results
    def process_authors(self, author_list: List[Dict]) -> str:
        """Convert the author list into a comma-separated string."""
        return ", ".join(
            [f"{author['forename']} {author['lastname']}" for author in author_list]
        )
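

# --- Minimal usage sketch (not part of the original module) ---
# Illustrates how the service might be driven end to end. The RagRequest
# constructor call below is an assumption: it only sets the attributes read in
# this file (query, is_rewrite, search_type, top_k, pubmed_topk), and the
# example values (e.g. search_type="relevance") are hypothetical. It also
# assumes BaseBioDocument exposes the fields passed to create_bio_document.
if __name__ == "__main__":

    async def _demo():
        rag_request = RagRequest(
            query="CRISPR off-target effects",
            is_rewrite=False,
            search_type="relevance",
            top_k=10,
            pubmed_topk=10,
        )
        service = PubMedSearchService()
        docs = await service.search(rag_request)
        for doc in docs:
            print(doc.title, doc.url)

    asyncio.run(_demo())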