File size: 6,435 Bytes
8f51ef2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
"""
Tool Factory
This module provides a factory for creating and managing different types of tools.
"""
import os
import sys
from typing import Dict, List, Any, Optional, Type
from pathlib import Path
# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from tools.general.base_tool_manager import BaseToolManager, tool_registry
# The real model factory is optional: prefer it, otherwise install a stub
# so callers can always call get_model() without guarding the import.
try:
    from models.model_factory import get_model
except ImportError:
    def get_model(model_name: str):
        """Stand-in model lookup used when ``models.model_factory`` is unavailable.

        No model backend could be imported, so no name ever resolves and the
        result is always ``None``.
        """
        resolved_model = None
        return resolved_model
class ToolFactory:
    """
    Factory for creating and managing different types of tools.

    Tool *classes* are registered by name in ``self._tool_classes``.
    ``create_tool`` instantiates one of them and records the instance in the
    shared ``tool_registry``. The lookup/cleanup helpers delegate to that
    registry.
    """

    # Tools whose managers are always constructed with a model. When the caller
    # gives no model name, ``_ECHO_PRIME_DEFAULT_MODEL`` is loaded instead.
    # NOTE(review): these names are not among those registered by
    # ``_register_echo_tools``; presumably they are registered elsewhere via
    # ``register_tool_class`` -- confirm, otherwise this branch is unreachable.
    _ECHO_PRIME_TOOLS = frozenset({"echo_prime_disease_prediction", "echo_prime_measurements"})
    _ECHO_PRIME_DEFAULT_MODEL = "echo_prime"

    def __init__(self):
        self._tool_classes: Dict[str, Type[BaseToolManager]] = {}
        self._register_default_tools()

    def register_tool_class(self, name: str, tool_class: Type[BaseToolManager]):
        """Register (or overwrite) the tool class available under ``name``."""
        self._tool_classes[name] = tool_class
        print(f"Registered tool class: {name}")

    def create_tool(self, tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
        """
        Create a tool instance and register it in the global ``tool_registry``.

        Args:
            tool_name: Name of a tool class previously passed to
                ``register_tool_class``.
            model_name: Name of the model to load via ``get_model``. When
                given, the tool class is constructed with the loaded model;
                otherwise it is constructed with no arguments. Tools listed in
                ``_ECHO_PRIME_TOOLS`` always receive a model and fall back to
                ``"echo_prime"`` when this is omitted or empty.
            **kwargs: Accepted for API compatibility but currently unused.
                They are NOT forwarded to the tool constructor.

        Returns:
            The new tool instance, or None if the tool class is unknown, the
            model cannot be loaded, or construction/registration raises.
        """
        try:
            tool_class = self._tool_classes.get(tool_name)
            if tool_class is None:
                print(f"Tool class not found: {tool_name}")
                return None

            if tool_name in self._ECHO_PRIME_TOOLS:
                # These managers require a model, so default one. The model is
                # loaded exactly once here; the old code loaded it twice when a
                # model_name was supplied.
                model_name = model_name or self._ECHO_PRIME_DEFAULT_MODEL
                print(f"Getting model for {tool_name}: {model_name}")
                model = get_model(model_name)
                # `is None` rather than truthiness: a valid model object may be falsy.
                if model is None:
                    print(f"Model not found for {tool_name}: {model_name}")
                    return None
                print(f"Model found for {tool_name}: {type(model)}")
                tool = tool_class(model)
            elif model_name:
                model = get_model(model_name)
                if model is None:
                    print(f"Model not found: {model_name}")
                    return None
                tool = tool_class(model)
            else:
                tool = tool_class()

            # Register in global registry
            tool_registry.register_tool(tool_name, tool, tool.config)
            return tool
        except Exception as e:
            # Best-effort creation: report and signal failure with None.
            print(f"Failed to create tool {tool_name}: {e}")
            return None

    def get_tool(self, tool_name: str) -> Optional[BaseToolManager]:
        """Return the manager stored in ``tool_registry`` under ``tool_name`` (via ``get_manager``)."""
        return tool_registry.get_manager(tool_name)

    def get_tool_instance(self, tool_name: str) -> Optional[Any]:
        """Return the actual tool instance (BaseTool) via ``tool_registry.get_tool``.

        NOTE(review): presumably this is the object wrapped by the manager that
        ``get_tool`` returns -- confirm against the registry implementation.
        """
        return tool_registry.get_tool(tool_name)

    def get_available_tools(self) -> List[str]:
        """Return the names of all registered tool *classes* (not necessarily instantiated)."""
        return list(self._tool_classes.keys())

    def get_ready_tools(self) -> List[str]:
        """Return the tool names reported by ``tool_registry.get_available_tools``.

        Presumably these are the tools that have been created and registered;
        verify against the registry.
        """
        return tool_registry.get_available_tools()

    def cleanup_all(self):
        """Clean up all tools by delegating to ``tool_registry.cleanup_all``."""
        tool_registry.cleanup_all()

    def _register_default_tools(self):
        """Register the built-in tool classes (called from ``__init__``)."""
        # Register echo tools with lazy loading to avoid circular imports
        self._register_echo_tools()
        # Add more tool registrations here as needed

    def _register_echo_tools(self):
        """Register the echo tool classes, importing them lazily.

        The import happens here rather than at module level to avoid circular
        imports. Failures are reported and swallowed, so a missing echo package
        leaves the factory usable but with no echo tools registered.
        """
        try:
            from tools.echo.echo_tool_managers import (
                EchoDiseasePredictionManager,
                EchoImageVideoGenerationManager,
                EchoMeasurementPredictionManager,
                EchoReportGenerationManager,
                EchoSegmentationManager,
                EchoViewClassificationManager,
            )
            self.register_tool_class("echo_disease_prediction", EchoDiseasePredictionManager)
            self.register_tool_class("echo_image_video_generation", EchoImageVideoGenerationManager)
            self.register_tool_class("echo_measurement_prediction", EchoMeasurementPredictionManager)
            self.register_tool_class("echo_report_generation", EchoReportGenerationManager)
            self.register_tool_class("echo_segmentation", EchoSegmentationManager)
            self.register_tool_class("echo_view_classification", EchoViewClassificationManager)
            print("✅ All echo tools registered successfully")
        except ImportError as e:
            print(f"Failed to register echo tools: {e}")
        except Exception as e:
            print(f"Error registering echo tools: {e}")
# Global tool factory shared by the module-level helper functions.
# It is instantiated at import time, so importing this module immediately runs
# ToolFactory.__init__ -> _register_default_tools, including the lazy echo-tool
# import and its console prints.
tool_factory = ToolFactory()
def create_tool(tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
    """Build ``tool_name`` through the module-wide ``tool_factory``.

    All arguments, including ``**kwargs``, are passed straight through to
    ``tool_factory.create_tool``; its result is returned unchanged.
    """
    build = tool_factory.create_tool
    return build(tool_name, model_name, **kwargs)
def get_tool(tool_name: str) -> Optional[BaseToolManager]:
    """Look up an existing tool manager by name via the module-wide ``tool_factory``."""
    factory = tool_factory
    found = factory.get_tool(tool_name)
    return found
def get_tool_instance(tool_name: str) -> Optional[Any]:
    """Fetch the underlying tool object for ``tool_name`` via the module-wide ``tool_factory``."""
    factory = tool_factory
    instance = factory.get_tool_instance(tool_name)
    return instance
def get_available_tools() -> List[str]:
    """List every tool name whose class is registered with the module-wide ``tool_factory``."""
    names = tool_factory.get_available_tools()
    return names
def get_ready_tools() -> List[str]:
    """List the tool names the module-wide ``tool_factory`` reports as ready."""
    ready = tool_factory.get_ready_tools()
    return ready
def cleanup_all_tools():
    """Ask the module-wide ``tool_factory`` to clean up every tool; returns nothing."""
    factory = tool_factory
    factory.cleanup_all()
    return None
|