import json
import os
import ssl
import aiohttp
import asyncio

from agents import function_tool
# from ..workers.baseclass import ResearchAgent, ResearchRunner
# from ..workers.utils.parse_output import create_type_parser
from typing import List, Union, Optional
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from crawl4ai import (
    AsyncWebCrawler,
    BrowserConfig,
    CacheMode,
    CrawlerRunConfig,
    DefaultMarkdownGenerator,
    PruningContentFilter,
)

load_dotenv()

CONTENT_LENGTH_LIMIT = 10000  # Trim scraped content to this length to avoid large context / token limit issues
SEARCH_PROVIDER = os.getenv("SEARCH_PROVIDER", "serper").lower()

# ------- DEFINE TYPES -------

class ScrapeResult(BaseModel):
    url: str = Field(description="The URL of the webpage")
    text: str = Field(description="The full text content of the webpage")
    title: str = Field(description="The title of the webpage")
    description: str = Field(description="A short description of the webpage")


class WebpageSnippet(BaseModel):
    url: str = Field(description="The URL of the webpage")
    title: str = Field(description="The title of the webpage")
    description: Optional[str] = Field(description="A short description of the webpage")


class SearchResults(BaseModel):
    results_list: List[WebpageSnippet]
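
# Illustrative examples of the data shapes defined above (values are hypothetical):
#   WebpageSnippet(url="https://example.com", title="Example", description="An example page")
#   ScrapeResult(url="https://example.com", title="Example", description="An example page", text="Full page text...")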


# ------- DEFINE TOOL -------

# Module-level variable to store the singleton SerperClient instance
_serper_client = None


async def web_search(query: str) -> Union[List[ScrapeResult], str]:
    """Perform a web search for a given query and get back the URLs along with their titles, descriptions and text contents.

    Args:
        query: The search query

    Returns:
        List of ScrapeResult objects, which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
    """
    # Only use SerperClient if the search provider is serper
    if SEARCH_PROVIDER == "openai":
        # For the OpenAI search provider this function should not be called directly;
        # the WebSearchTool from the agents module is used instead
        return "The web_search function is not used when SEARCH_PROVIDER is set to 'openai'. Please check your configuration."
    else:
        try:
            # Lazy initialization of SerperClient
            global _serper_client
            if _serper_client is None:
                _serper_client = SerperClient()

            search_results = await _serper_client.search(
                query, filter_for_relevance=True, max_results=5
            )
            results = await scrape_urls(search_results)
            return results
        except Exception as e:
            # Return a user-friendly error message
            return f"Sorry, I encountered an error while searching: {str(e)}"


# ------- DEFINE AGENT FOR FILTERING SEARCH RESULTS BY RELEVANCE -------

FILTER_AGENT_INSTRUCTIONS = f"""
You are a search result filter. Your task is to analyze a list of SERP search results and determine which ones are relevant
to the original query based on the link, title and snippet. Return only the relevant results in the specified format.

- Remove any results that refer to entities that have similar names to the queried entity, but are not the same.
- E.g. if the query asks about a company "Acme Inc, acme.com", remove results with "acmesolutions.com" or "acme.net" in the link.

Only output JSON. Follow the JSON schema below. Do not output anything else. I will be parsing this with Pydantic so output valid JSON only:
{SearchResults.model_json_schema()}
"""

# selected_model = fast_model
#
# filter_agent = ResearchAgent(
#     name="SearchFilterAgent",
#     instructions=FILTER_AGENT_INSTRUCTIONS,
#     model=selected_model,
#     output_type=SearchResults if model_supports_structured_output(selected_model) else None,
#     output_parser=create_type_parser(SearchResults) if not model_supports_structured_output(selected_model) else None
# )


# ------- DEFINE UNDERLYING TOOL LOGIC -------

# Create a shared SSL context for all outgoing connections
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.set_ciphers("DEFAULT:@SECLEVEL=1")  # Allow older cipher suites


class SerperClient:
    """A client for the Serper API to perform Google searches."""

    def __init__(self, api_key: str = None):
        self.api_key = api_key or os.getenv("SERPER_API_KEY")
        if not self.api_key:
            raise ValueError(
                "No API key provided. Set the SERPER_API_KEY environment variable."
            )

        self.url = "https://google.serper.dev/search"
        self.headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}

    async def search(
        self, query: str, filter_for_relevance: bool = True, max_results: int = 5
    ) -> List[WebpageSnippet]:
        """Perform a Google search using the Serper API and fetch basic details for the top results.

        Args:
            query: The search query
            filter_for_relevance: Whether to filter out irrelevant and unreachable results
            max_results: Maximum number of results to return (max 10)

        Returns:
            List of WebpageSnippet objects for the top organic results
        """
        connector = aiohttp.TCPConnector(ssl=ssl_context)
        async with aiohttp.ClientSession(connector=connector) as session:
            async with session.post(
                self.url, headers=self.headers, json={"q": query, "autocorrect": False}
            ) as response:
                response.raise_for_status()
                results = await response.json()
                results_list = [
                    WebpageSnippet(
                        url=result.get("link", ""),
                        title=result.get("title", ""),
                        description=result.get("snippet", ""),
                    )
                    for result in results.get("organic", [])
                ]

                if not results_list:
                    return []

                if not filter_for_relevance:
                    return results_list[:max_results]
                # return results_list[:max_results]
                return await self._filter_results(results_list, query, max_results=max_results)
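
    # Example of the Serper payload shape consumed above (illustrative, trimmed):
    #   {"organic": [{"link": "https://example.com", "title": "Example", "snippet": "..."}]}
    # Only the "organic" list and its "link"/"title"/"snippet" keys are used; other fields are ignored.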

    async def _filter_results(
        self, results: List[WebpageSnippet], query: str, max_results: int = 5
    ) -> List[WebpageSnippet]:
        # Get rid of PubMed source data
        filtered_results = [
            res
            for res in results
            if "pmc.ncbi.nlm.nih.gov" not in res.url
            and "pubmed.ncbi.nlm.nih.gov" not in res.url
        ]

        # # Get rid of unrelated data
        # serialized_results = [result.model_dump() if isinstance(result, WebpageSnippet) else result for result in
        #                       filtered_results]
        #
        # user_prompt = f"""
        # Original search query: {query}
        #
        # Search results to analyze:
        # {json.dumps(serialized_results, indent=2)}
        #
        # Return {max_results} search results or less.
        # """
        #
        # try:
        #     result = await ResearchRunner.run(filter_agent, user_prompt)
        #     output = result.final_output_as(SearchResults)
        #     return output.results_list
        # except Exception as e:
        #     print("Error filtering urls:", str(e))
        #     return filtered_results[:max_results]

        async def fetch_url(session, url):
            try:
                async with session.get(url, timeout=5) as response:
                    return response.status == 200
            except Exception as e:
                print(f"Error accessing {url}: {str(e)}")
                return False  # False means the URL is not reachable

        async def filter_unreachable_urls(results):
            async with aiohttp.ClientSession() as session:
                tasks = [fetch_url(session, res.url) for res in results]
                reachable = await asyncio.gather(*tasks)
                return [
                    res for res, can_access in zip(results, reachable) if can_access
                ]

        reachable_results = await filter_unreachable_urls(filtered_results)

        # Return the first `max_results` results, or fewer if not enough are reachable
        return reachable_results[:max_results]


async def scrape_urls(items: List[WebpageSnippet]) -> List[ScrapeResult]:
    """Fetch text content from the provided URLs.

    Args:
        items: List of WebpageSnippet items to extract content from

    Returns:
        List of ScrapeResult objects, which have the following fields:
            - url: The URL of the search result
            - title: The title of the search result
            - description: The description of the search result
            - text: The full text content of the search result
    """
    connector = aiohttp.TCPConnector(ssl=ssl_context)
    async with aiohttp.ClientSession(connector=connector) as session:
        # Create a list of tasks for concurrent execution
        tasks = []
        for item in items:
            if item.url:  # Skip empty URLs
                tasks.append(fetch_and_process_url(session, item))

        # Execute all tasks concurrently and gather the results
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Filter out errors and return the successful results
        return [r for r in results if isinstance(r, ScrapeResult)]


async def fetch_and_process_url(
    session: aiohttp.ClientSession, item: WebpageSnippet
) -> ScrapeResult:
    """Helper function to fetch and process a single URL."""
    if not is_valid_url(item.url):
        return ScrapeResult(
            url=item.url,
            title=item.title,
            description=item.description,
            text="Error fetching content: URL contains restricted file extension",
        )

    try:
        async with session.get(item.url, timeout=8) as response:
            if response.status == 200:
                content = await response.text()
                # Run html_to_text in a thread pool to avoid blocking the event loop
                text_content = await asyncio.get_event_loop().run_in_executor(
                    None, html_to_text, content
                )
                text_content = text_content[:CONTENT_LENGTH_LIMIT]  # Trim content to avoid exceeding the token limit
                return ScrapeResult(
                    url=item.url,
                    title=item.title,
                    description=item.description,
                    text=text_content,
                )
            else:
                # Instead of raising, return a ScrapeResult with an error message
                return ScrapeResult(
                    url=item.url,
                    title=item.title,
                    description=item.description,
                    text=f"Error fetching content: HTTP {response.status}",
                )
    except Exception as e:
        # Instead of raising, return a ScrapeResult with an error message
        return ScrapeResult(
            url=item.url,
            title=item.title,
            description=item.description,
            text=f"Error fetching content: {str(e)}",
        )


def html_to_text(html_content: str) -> str:
    """
    Strips out all of the unnecessary elements from the HTML content to prepare it for text extraction / LLM processing.
    """
    # Parse the HTML using lxml for speed
    soup = BeautifulSoup(html_content, "lxml")

    # Extract text from relevant tags only
    tags_to_extract = ("h1", "h2", "h3", "h4", "h5", "h6", "p", "li", "blockquote")

    # Use a generator expression for efficiency
    extracted_text = "\n".join(
        element.get_text(strip=True)
        for element in soup.find_all(tags_to_extract)
        if element.get_text(strip=True)
    )

    return extracted_text
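
# Example (illustrative): given "<h1>Title</h1><p>Body text</p><script>x()</script>",
# html_to_text returns "Title\nBody text"; only heading, paragraph, list item and
# blockquote text is kept, while script tags and other markup are ignored.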


def is_valid_url(url: str) -> bool:
    """Check that a URL does not contain restricted file extensions."""
    if any(
        ext in url
        for ext in [
            ".pdf",
            ".doc",
            ".xls",
            ".ppt",
            ".zip",
            ".rar",
            ".7z",
            ".txt",
            ".js",
            ".xml",
            ".css",
            ".png",
            ".jpg",
            ".jpeg",
            ".gif",
            ".ico",
            ".svg",
            ".webp",
            ".mp3",
            ".mp4",
            ".avi",
            ".mov",
            ".wmv",
            ".flv",
            ".wma",
            ".wav",
            ".m4a",
            ".m4v",
            ".m4b",
            ".m4p",
            ".m4u",
        ]
    ):
        return False
    return True
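
# Example (illustrative): is_valid_url("https://example.com/report.pdf") returns False,
# while is_valid_url("https://example.com/articles/report") returns True. Note that the
# check is a simple substring match anywhere in the URL, not only on the path suffix.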


async def url_to_contents(url):
    async with AsyncWebCrawler() as crawler:
        result = await crawler.arun(
            url=url,
        )
        # print(result.markdown)
        return result.markdown


async def url_to_fit_contents(res):
    str_fit_max = 40000  # 40,000 characters is roughly 10,000 tokens, so five pages together stay under 50k tokens
    browser_config = BrowserConfig(
        headless=True,
        verbose=True,
    )
    run_config = CrawlerRunConfig(
        cache_mode=CacheMode.DISABLED,
        markdown_generator=DefaultMarkdownGenerator(
            content_filter=PruningContentFilter(
                threshold=1.0, threshold_type="fixed", min_word_threshold=0
            )
        ),
        # markdown_generator=DefaultMarkdownGenerator(
        #     content_filter=BM25ContentFilter(user_query="WHEN_WE_FOCUS_BASED_ON_A_USER_QUERY", bm25_threshold=1.0)
        # ),
    )

    try:
        async with AsyncWebCrawler(config=browser_config) as crawler:
            # Use asyncio.wait_for to enforce a timeout on the crawl
            result = await asyncio.wait_for(
                crawler.arun(url=res.url, config=run_config), timeout=15  # timeout in seconds
            )
            print(f"Characters before filtering: {len(result.markdown.raw_markdown)}")
            print(f"Characters after filtering: {len(result.markdown.fit_markdown)}")
            return result.markdown.fit_markdown[:str_fit_max]  # On success, return the first str_fit_max characters
    except asyncio.TimeoutError:
        print(f"Timeout occurred while accessing {res.url}.")
        return res.text[:str_fit_max]  # On timeout, fall back to the roughly extracted text in res
    except Exception as e:
        print(f"Exception occurred: {str(e)}")
        return res.text[:str_fit_max]  # On any other error, fall back to the roughly extracted text in res
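

# Minimal usage sketch (illustrative only): assumes the module is run as a script with
# SERPER_API_KEY set and SEARCH_PROVIDER left at its "serper" default. In the wider
# application, web_search is expected to be invoked by an agent framework instead.
if __name__ == "__main__":
    async def _demo():
        results = await web_search("example search query")
        if isinstance(results, str):
            print(results)  # configuration or error message
        else:
            for r in results:
                print(f"{r.url} - {r.title} ({len(r.text)} chars scraped)")

    asyncio.run(_demo())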