File size: 6,435 Bytes
8f51ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Tool Factory

This module provides a factory for creating and managing different types of tools.
"""

import os
import sys
from typing import Dict, List, Any, Optional, Type
from pathlib import Path

# Add the directory two levels above this file's directory to sys.path so the
# absolute `tools.*` / `models.*` imports below can be resolved
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from tools.general.base_tool_manager import BaseToolManager, tool_registry

# Import the project model factory if available. If the import fails, a stub
# with the same name is defined instead so this module can still be loaded.
try:
    from models.model_factory import get_model
except ImportError:
    def get_model(model_name: str) -> None:
        """Fallback model factory used when ``models.model_factory`` cannot be imported.

        Args:
            model_name: Name of the requested model (ignored by this stub).

        Returns:
            Always None, so callers that request a model through this stub
            see the model as "not found".
        """
        return None


class ToolFactory:
    """
    Factory for creating and managing different types of tools.

    The factory keeps a ``name -> manager class`` mapping in
    ``self._tool_classes``. Created instances are handed to the module-level
    ``tool_registry``, which is also what the lookup/cleanup methods query.
    """

    # Tool names whose manager is always constructed with a model.
    # NOTE(review): neither name is registered by _register_echo_tools(), so
    # this path only triggers for classes added via register_tool_class().
    _MODEL_REQUIRED_TOOLS = frozenset(
        {"echo_prime_disease_prediction", "echo_prime_measurements"}
    )
    # Model used for the tools above when the caller does not name one.
    _DEFAULT_MODEL_NAME = "echo_prime"

    def __init__(self):
        """Create an empty class table and populate it with the default tools."""
        self._tool_classes: Dict[str, Type[BaseToolManager]] = {}
        self._register_default_tools()

    def register_tool_class(self, name: str, tool_class: Type[BaseToolManager]):
        """Register a tool class under ``name``, replacing any previous entry.

        Args:
            name: Key later passed to ``create_tool``.
            tool_class: Class that ``create_tool`` will instantiate.
        """
        self._tool_classes[name] = tool_class
        print(f"Registered tool class: {name}")

    def create_tool(self, tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
        """
        Create a tool instance and record it in the global ``tool_registry``.

        Args:
            tool_name: Name of the tool to create
            model_name: Name of the model to use (if required)
            **kwargs: Additional configuration parameters. Currently accepted
                for interface compatibility only; they are not forwarded to
                the tool class.

        Returns:
            Tool instance or None if creation failed (unknown tool name,
            model lookup failure, or any exception raised while building or
            registering the tool).
        """
        try:
            # Get tool class
            tool_class = self._tool_classes.get(tool_name)
            if not tool_class:
                print(f"Tool class not found: {tool_name}")
                return None

            # Resolve the explicitly requested model, if any.
            model = None
            if model_name:
                model = get_model(model_name)
                if not model:
                    print(f"Model not found: {model_name}")
                    return None

            # Tools that cannot work without a model fall back to the default
            # model. A model already resolved above is reused rather than
            # being fetched a second time.
            if tool_name in self._MODEL_REQUIRED_TOOLS:
                if model is None:
                    model_name = self._DEFAULT_MODEL_NAME
                    print(f"Getting model for {tool_name}: {model_name}")
                    model = get_model(model_name)
                    if not model:
                        print(f"Model not found for {tool_name}: {model_name}")
                        return None
                print(f"Model found for {tool_name}: {type(model)}")

            # At this point model is either None or a usable (truthy) object.
            tool = tool_class(model) if model is not None else tool_class()

            # Register in global registry
            tool_registry.register_tool(tool_name, tool, tool.config)

            return tool

        except Exception as e:
            # Deliberate best-effort boundary: report the failure and signal
            # it to the caller with None instead of propagating.
            print(f"Failed to create tool {tool_name}: {e}")
            return None

    def get_tool(self, tool_name: str) -> Optional[BaseToolManager]:
        """Get an existing tool manager via ``tool_registry.get_manager``."""
        return tool_registry.get_manager(tool_name)

    def get_tool_instance(self, tool_name: str) -> Optional[Any]:
        """Get the actual tool instance (BaseTool) via ``tool_registry.get_tool``."""
        return tool_registry.get_tool(tool_name)

    def get_available_tools(self) -> List[str]:
        """Get the names of all tool classes registered with this factory."""
        return list(self._tool_classes)

    def get_ready_tools(self) -> List[str]:
        """Get the tool names reported by ``tool_registry.get_available_tools``."""
        return tool_registry.get_available_tools()

    def cleanup_all(self):
        """Clean up all tools by delegating to ``tool_registry.cleanup_all``."""
        tool_registry.cleanup_all()

    def _register_default_tools(self):
        """Register default tool classes."""
        # Register echo tools with lazy loading to avoid circular imports
        self._register_echo_tools()

        # Add more tool registrations here as needed

    def _register_echo_tools(self):
        """Register echo tools with lazy loading.

        The import happens inside the method so a missing or broken echo
        package only produces a printed message instead of failing the
        factory's construction.
        """
        try:
            # Import and register echo tools
            from tools.echo.echo_tool_managers import (
                EchoDiseasePredictionManager,
                EchoImageVideoGenerationManager,
                EchoMeasurementPredictionManager,
                EchoReportGenerationManager,
                EchoSegmentationManager,
                EchoViewClassificationManager,
            )

            echo_tools = (
                ("echo_disease_prediction", EchoDiseasePredictionManager),
                ("echo_image_video_generation", EchoImageVideoGenerationManager),
                ("echo_measurement_prediction", EchoMeasurementPredictionManager),
                ("echo_report_generation", EchoReportGenerationManager),
                ("echo_segmentation", EchoSegmentationManager),
                ("echo_view_classification", EchoViewClassificationManager),
            )
            for name, manager_class in echo_tools:
                self.register_tool_class(name, manager_class)

            print("✅ All echo tools registered successfully")
        except ImportError as e:
            print(f"Failed to register echo tools: {e}")
        except Exception as e:
            print(f"Error registering echo tools: {e}")

# Global tool factory: one shared ToolFactory built when this module is
# imported. Constructing it runs the default tool registration, and the
# module-level helper functions below all delegate to this instance.
tool_factory = ToolFactory()


def create_tool(tool_name: str, model_name: Optional[str] = None, **kwargs) -> Optional[BaseToolManager]:
    """Build a tool through the shared module-level ``tool_factory``.

    Thin convenience wrapper: every argument is forwarded unchanged to
    ``tool_factory.create_tool`` and its result is returned as-is.
    """
    build = tool_factory.create_tool
    return build(tool_name, model_name, **kwargs)


def get_tool(tool_name: str) -> Optional[BaseToolManager]:
    """Look up an existing tool by name via the shared ``tool_factory``.

    Returns whatever ``tool_factory.get_tool`` returns for ``tool_name``.
    """
    manager = tool_factory.get_tool(tool_name)
    return manager


def get_tool_instance(tool_name: str) -> Optional[Any]:
    """Fetch the underlying tool instance by name via the shared ``tool_factory``.

    Returns whatever ``tool_factory.get_tool_instance`` returns for
    ``tool_name``.
    """
    instance = tool_factory.get_tool_instance(tool_name)
    return instance


def get_available_tools() -> List[str]:
    """List the tool names known to the shared ``tool_factory``.

    Delegates to ``tool_factory.get_available_tools``.
    """
    names = tool_factory.get_available_tools()
    return names


def get_ready_tools() -> List[str]:
    """List the tool names the shared ``tool_factory`` reports as ready.

    Delegates to ``tool_factory.get_ready_tools``.
    """
    ready = tool_factory.get_ready_tools()
    return ready


def cleanup_all_tools():
    """Release every tool by delegating to ``tool_factory.cleanup_all``.

    Returns None; the delegate's return value is discarded.
    """
    factory = tool_factory
    factory.cleanup_all()