|
|
""" |
|
|
Tool Factory |
|
|
|
|
|
This module provides a factory for creating and managing different types of tools. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
from typing import Dict, List, Any, Optional, Type |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
# Make the directory three levels above this file importable so the absolute
# ``tools.*`` / ``models.*`` imports in this module can resolve when the
# project is not installed (presumably this directory is the project root --
# confirm against the repository layout).
# The membership check keeps sys.path free of duplicate entries when this
# module is reloaded or the directory was already added elsewhere.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)
|
|
|
|
|
from tools.general.base_tool_manager import BaseToolManager, tool_registry |
|
|
|
|
|
|
|
|
# Model loading is optional for this module: when the model factory cannot be
# imported, a stand-in with the same name is defined so callers can still
# call ``get_model`` and receive "no model".
try:
    from models.model_factory import get_model
except ImportError:

    def get_model(model_name: str):
        """Stand-in used when ``models.model_factory`` cannot be imported.

        Always returns ``None``, whatever ``model_name`` is.
        """
        return None
|
|
|
|
|
|
|
|
class ToolFactory:
    """Factory for creating and managing different types of tools.

    Tool manager classes are kept in a per-instance ``name -> class`` mapping.
    Instances built by :meth:`create_tool` are handed to the module-level
    ``tool_registry``; the lookup, readiness and cleanup methods delegate to
    that registry.
    """

    # Tool names that are always constructed with a model. When no model name
    # is supplied for one of them, ``_DEFAULT_MODEL_NAME`` is loaded instead.
    # NOTE(review): neither name matches a name registered by
    # ``_register_echo_tools`` (e.g. ``echo_disease_prediction``,
    # ``echo_measurement_prediction``), so this only takes effect for classes
    # registered under these exact names -- confirm the intended names.
    _MODEL_REQUIRED_TOOLS = ("echo_prime_disease_prediction", "echo_prime_measurements")
    _DEFAULT_MODEL_NAME = "echo_prime"

    def __init__(self):
        """Create an empty class mapping and register the default tool classes."""
        self._tool_classes: Dict[str, Type[BaseToolManager]] = {}
        self._register_default_tools()

    def register_tool_class(self, name: str, tool_class: Type[BaseToolManager]):
        """Register ``tool_class`` under ``name``, replacing any previous entry."""
        self._tool_classes[name] = tool_class
        print(f"Registered tool class: {name}")

    def create_tool(self, tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
        """Create a tool instance and register it with ``tool_registry``.

        Args:
            tool_name: Name of a registered tool class.
            model_name: Name of the model to load through ``get_model``.
                Tools listed in ``_MODEL_REQUIRED_TOOLS`` fall back to
                ``_DEFAULT_MODEL_NAME`` when this is omitted or empty.
            **kwargs: Additional configuration parameters. They are accepted
                but currently not forwarded to the tool class.

        Returns:
            The created tool, or ``None`` if the tool class is unknown, the
            model could not be loaded, or construction/registration raised.
        """
        # Creation is best-effort: any failure is reported and turned into a
        # ``None`` result rather than propagated to the caller.
        try:
            tool_class = self._tool_classes.get(tool_name)
            if tool_class is None:
                print(f"Tool class not found: {tool_name}")
                return None

            # Load the explicitly requested model, if any. ``None`` is the
            # "not available" signal (it is what the fallback ``get_model``
            # returns); a loaded model that merely evaluates as falsy is
            # still a valid model.
            model = None
            if model_name:
                model = get_model(model_name)
                if model is None:
                    print(f"Model not found: {model_name}")
                    return None

            if tool_name in self._MODEL_REQUIRED_TOOLS:
                # Reuse the model loaded above instead of loading it a second
                # time; only fetch the default when none was requested.
                if model is None:
                    model_name = self._DEFAULT_MODEL_NAME
                    print(f"Getting model for {tool_name}: {model_name}")
                    model = get_model(model_name)
                    if model is None:
                        print(f"Model not found for {tool_name}: {model_name}")
                        return None
                print(f"Model found for {tool_name}: {type(model)}")

            # Tool classes are called with the model when there is one and
            # with no arguments otherwise.
            tool = tool_class(model) if model is not None else tool_class()

            tool_registry.register_tool(tool_name, tool, tool.config)
            return tool
        except Exception as e:
            print(f"Failed to create tool {tool_name}: {e}")
            return None

    def get_tool(self, tool_name: str) -> Optional[BaseToolManager]:
        """Return what ``tool_registry.get_manager`` holds for ``tool_name``."""
        return tool_registry.get_manager(tool_name)

    def get_tool_instance(self, tool_name: str) -> Optional[Any]:
        """Return what ``tool_registry.get_tool`` holds for ``tool_name``.

        Per the original documentation this is the actual tool instance
        (``BaseTool``) rather than its manager.
        """
        return tool_registry.get_tool(tool_name)

    def get_available_tools(self) -> List[str]:
        """Return the names of all tool classes registered with this factory."""
        return list(self._tool_classes)

    def get_ready_tools(self) -> List[str]:
        """Return the tool names reported by ``tool_registry.get_available_tools``."""
        return tool_registry.get_available_tools()

    def cleanup_all(self):
        """Clean up all tools by delegating to ``tool_registry.cleanup_all``."""
        tool_registry.cleanup_all()

    def _register_default_tools(self):
        """Register the default tool classes (currently the echo tools)."""
        self._register_echo_tools()

    def _register_echo_tools(self):
        """Import the echo tool managers and register each under its tool name.

        The import happens inside this method rather than at module top
        level, so a missing or broken ``tools.echo`` package is reported and
        leaves the factory without echo tools instead of failing the import
        of this module. Note that this runs from ``__init__``, i.e. when the
        factory is constructed, not on first use of a tool.
        """
        # Registration and the success message stay inside the ``try`` so
        # that any error raised while reporting is handled the same way as an
        # import failure.
        try:
            from tools.echo.echo_tool_managers import (
                EchoDiseasePredictionManager,
                EchoImageVideoGenerationManager,
                EchoMeasurementPredictionManager,
                EchoReportGenerationManager,
                EchoSegmentationManager,
                EchoViewClassificationManager,
            )

            echo_tools = (
                ("echo_disease_prediction", EchoDiseasePredictionManager),
                ("echo_image_video_generation", EchoImageVideoGenerationManager),
                ("echo_measurement_prediction", EchoMeasurementPredictionManager),
                ("echo_report_generation", EchoReportGenerationManager),
                ("echo_segmentation", EchoSegmentationManager),
                ("echo_view_classification", EchoViewClassificationManager),
            )
            for name, manager_class in echo_tools:
                self.register_tool_class(name, manager_class)

            print("✅ All echo tools registered successfully")
        except ImportError as e:
            print(f"Failed to register echo tools: {e}")
        except Exception as e:
            print(f"Error registering echo tools: {e}")
|
|
|
|
|
|
|
|
|
|
|
# Module-level ToolFactory instance used by the convenience functions in this
# module. Constructing it runs the default tool registration at import time.
tool_factory = ToolFactory()
|
|
|
|
|
|
|
|
def create_tool(tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
    """Build a tool through the module-level ``tool_factory``.

    All arguments are passed straight through to
    ``tool_factory.create_tool`` and its result is returned unchanged.
    """
    created = tool_factory.create_tool(tool_name, model_name, **kwargs)
    return created
|
|
|
|
|
|
|
|
def get_tool(tool_name: str) -> Optional[BaseToolManager]:
    """Look up ``tool_name`` through the module-level ``tool_factory``.

    Returns whatever ``tool_factory.get_tool`` returns for that name.
    """
    manager = tool_factory.get_tool(tool_name)
    return manager
|
|
|
|
|
|
|
|
def get_tool_instance(tool_name: str) -> Optional[Any]:
    """Fetch the tool instance for ``tool_name`` via the module-level ``tool_factory``.

    Returns whatever ``tool_factory.get_tool_instance`` returns for that name.
    """
    instance = tool_factory.get_tool_instance(tool_name)
    return instance
|
|
|
|
|
|
|
|
def get_available_tools() -> List[str]:
    """List the tool names known to the module-level ``tool_factory``.

    Returns the result of ``tool_factory.get_available_tools`` unchanged.
    """
    names = tool_factory.get_available_tools()
    return names
|
|
|
|
|
|
|
|
def get_ready_tools() -> List[str]:
    """List the ready tool names reported by the module-level ``tool_factory``.

    Returns the result of ``tool_factory.get_ready_tools`` unchanged.
    """
    ready = tool_factory.get_ready_tools()
    return ready
|
|
|
|
|
|
|
|
def cleanup_all_tools():
    """Clean up every tool by delegating to ``tool_factory.cleanup_all``.

    Returns ``None``; the result of the delegated call is not passed on.
    """
    tool_factory.cleanup_all()
|
|
|