Echo / tools /echoflow /echoflow_integrated_tool.py
moein99's picture
Initial Echo Space
8f51ef2
#!/usr/bin/env python3
"""
EchoFlow Integrated Tool
This tool integrates EchoFlow into the main echo analysis system.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Type
from pathlib import Path
import tempfile
import shutil
import datetime
import os
import sys
import numpy as np
import cv2
import torch
from pydantic import BaseModel, Field, field_validator
from langchain_core.tools import BaseTool
from langchain_core.callbacks import (
CallbackManagerForToolRun,
AsyncCallbackManagerForToolRun,
)
# Import our fixed EchoFlow implementation
from .echoflow_final_working import EchoFlowFinal
# ----------------------------- Input schema -----------------------------
class EchoFlowGenerationInput(BaseModel):
    """Generate synthetic echo images and videos using EchoFlow."""

    # Argument schema for EchoFlowGenerationTool. The class docstring and the
    # Field descriptions become part of the pydantic JSON schema (presumably
    # what the calling agent sees), so their text is kept unchanged.

    views: List[str] = Field(
        default_factory=lambda: ["A4C", "PLAX", "PSAX"],
        description="Cardiac echo views to synthesize (e.g., A4C, PLAX, PSAX).",
    )
    ejection_fractions: List[float] = Field(
        default_factory=lambda: [0.35, 0.55, 0.70],
        description="Ejection fraction values (0.0 to 1.0) used to condition the generation.",
    )
    num_frames: int = Field(16, ge=1, le=64, description="Number of frames in generated videos.")
    timestep: float = Field(0.5, ge=0.0, le=1.0, description="Diffusion timestep for generation.")
    outdir: Optional[str] = Field(
        None,
        description="Root output dir. If omitted, a timestamped folder is created under the tool temp dir.",
    )
    save_features: bool = Field(True, description="Save generated features as numpy arrays.")
    save_metadata: bool = Field(True, description="Save generation metadata.")

    @field_validator("views")
    @classmethod
    def _nonempty_views(cls, v: List[str]) -> List[str]:
        """Reject an empty view list and blank (empty/whitespace-only) view names.

        View names are used verbatim in output file names by the tool, so a
        blank name would yield files such as ``_EF0.35_mask.npy``.
        """
        if not v:
            raise ValueError("At least one view must be provided.")
        if any(not view.strip() for view in v):
            raise ValueError("View names must be non-empty strings.")
        return v

    @field_validator("ejection_fractions")
    @classmethod
    def _valid_efs(cls, v: List[float]) -> List[float]:
        """Reject an empty list and any value outside the closed range [0.0, 1.0].

        The range test is written as ``not (0.0 <= x <= 1.0)`` so that NaN,
        for which every comparison is False, is rejected as well.
        """
        if not v:
            raise ValueError("At least one ejection fraction must be provided.")
        for x in v:
            if not 0.0 <= x <= 1.0:
                raise ValueError(f"Ejection fraction {x} out of range [0.0, 1.0].")
        return v
# ----------------------------- Tool class -------------------------------
class EchoFlowGenerationTool(BaseTool):
    """EchoFlow generation tool integrated with the main echo analysis system.

    Wraps an ``EchoFlowFinal`` generator as a LangChain tool. Each run calls
    ``generate_synthetic_echo`` once per (view, ejection fraction) pair and
    writes outputs under a per-run directory that contains ``features/``,
    ``metadata/``, ``masks/`` and ``videos/`` sub-folders.
    """

    name: str = "echoflow_generation"
    description: str = (
        "Generate synthetic echocardiography images and videos using EchoFlow. "
        "Creates realistic echo data for training, testing, and augmentation purposes. "
        "Supports multiple views (A4C, PLAX, PSAX) and ejection fraction conditioning."
    )
    args_schema: Type[BaseModel] = EchoFlowGenerationInput
    # Field defaults; __init__ overwrites device and temp_dir on every instance.
    device: Optional[str] = "cuda"
    temp_dir: Path = Path("temp")
    # Left as None when model loading fails in __init__; _ensure_echoflow guards on it.
    echoflow_generator: Optional[EchoFlowFinal] = None

    def __init__(self, device: Optional[str] = None, temp_dir: Optional[str] = None):
        """Create the tool and try to load the EchoFlow config and models.

        Args:
            device: Torch device string. When ``None``, ``"cuda"`` is used if
                ``torch.cuda.is_available()``; otherwise ``"cpu"``.
            temp_dir: Root for default output folders. When ``None``, a fresh
                directory is created with :func:`tempfile.mkdtemp` (it is not
                removed by this class).

        Loading failures are printed and leave ``echoflow_generator`` as
        ``None``, so construction itself does not fail. :meth:`_run` raises
        instead.
        """
        super().__init__()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.temp_dir = Path(temp_dir or tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        # Best-effort initialization: report and continue without a generator.
        try:
            self.echoflow_generator = EchoFlowFinal(device=self.device)
            if not self.echoflow_generator.load_config():
                print("⚠️ Could not load EchoFlow config, using defaults")
            if not self.echoflow_generator.load_models():
                raise RuntimeError("Failed to load EchoFlow models")
            print("✅ EchoFlow generator initialized successfully")
        except Exception as e:
            print(f"❌ Failed to initialize EchoFlow generator: {e}")
            self.echoflow_generator = None

    # ----------------------------- helpers -----------------------------
    def _ensure_echoflow(self):
        """Raise ``RuntimeError`` if the generator was not initialized."""
        if self.echoflow_generator is None:
            raise RuntimeError(
                "EchoFlow generator not initialized. Check model loading and dependencies."
            )

    @staticmethod
    def _ensure_dirs(root: Path) -> Dict[str, Path]:
        """Create the output sub-directories under *root* and return them by name.

        NOTE(review): ``videos`` is created but nothing in this class writes to it.
        """
        d = {
            "features": root / "features",
            "metadata": root / "metadata",
            "masks": root / "masks",
            "videos": root / "videos",
        }
        for p in d.values():
            p.mkdir(parents=True, exist_ok=True)
        return d

    @staticmethod
    def _save_numpy(path: Path, arr: np.ndarray) -> str:
        """Save *arr* with :func:`numpy.save` and return the path as a string."""
        np.save(str(path), arr)
        return str(path)

    @staticmethod
    def _save_json(path: Path, data: Dict[str, Any]) -> str:
        """Write *data* as indented JSON (non-JSON values via ``str``); return the path."""
        import json
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
        return str(path)

    def _generate_one(
        self,
        view: str,
        ef: float,
        num_frames: int,
        timestep: float,
        save_features: bool,
        save_metadata: bool,
        paths: Dict[str, Path],
        view_rec: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Generate and persist one (view, EF) sample and return its result record.

        Paths of saved feature and metadata files are appended to *view_rec*
        as they are written. Exceptions propagate to :meth:`_run`, which
        records them as failures.
        """
        print(f"🎬 Generating {view} view with EF={ef:.2f}")
        # Placeholder conditioning mask of uniform random noise, not a real
        # segmentation (in real implementation, this would be loaded).
        dummy_mask = np.random.randint(0, 255, (400, 400), dtype=np.uint8)
        # NOTE(review): ``timestep`` is only recorded in metadata and is not
        # passed to the generator. Confirm whether generate_synthetic_echo accepts it.
        result = self.echoflow_generator.generate_synthetic_echo(
            mask=dummy_mask,
            view_type=view,
            ejection_fraction=ef,
            num_frames=num_frames,
        )
        if not result["success"]:
            error = result.get("error", "Unknown error")
            print(f"❌ {view} EF={ef:.2f} generation failed: {error}")
            return {"success": False, "error": error}

        stem = f"{view}_EF{ef:.2f}"
        features_path: Optional[Path] = None
        metadata_path: Optional[Path] = None
        if save_features:
            features_path = paths["features"] / f"{stem}_features.npy"
            self._save_numpy(features_path, result["video_features"])
            view_rec["features_saved"].append(str(features_path))
        if save_metadata:
            metadata = {
                "view": view,
                "ejection_fraction": ef,
                "num_frames": num_frames,
                "timestep": timestep,
                "video_features_shape": result["video_features"].shape,
                "mask_processed_shape": result["mask_processed"].shape,
                "timestamp": result["timestamp"],
                "device": result["device"],
            }
            metadata_path = paths["metadata"] / f"{stem}_metadata.json"
            self._save_json(metadata_path, metadata)
            view_rec["metadata_saved"].append(str(metadata_path))
        # The processed mask is always saved, regardless of the save_* flags.
        mask_path = paths["masks"] / f"{stem}_mask.npy"
        self._save_numpy(mask_path, result["mask_processed"])
        print(f"✅ {view} EF={ef:.2f} generated successfully")
        return {
            "success": True,
            "video_features_shape": result["video_features"].shape,
            "features_path": str(features_path) if features_path is not None else None,
            "metadata_path": str(metadata_path) if metadata_path is not None else None,
            "mask_path": str(mask_path),
        }

    # ----------------------------- core run -----------------------------
    def _run(
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Generate one sample per (view, EF) pair and save the outputs to disk.

        Args:
            views: View names to generate.
            ejection_fractions: EF values used to condition each generation.
            num_frames: Frames per generated video.
            timestep: Recorded in the run and per-sample metadata.
            outdir: Output root. Defaults to a UTC-timestamped folder under ``temp_dir``.
            save_features: Save ``video_features`` as ``.npy`` files.
            save_metadata: Save per-sample JSON metadata.
            run_manager: Accepted for the LangChain interface; unused.

        Returns:
            A dict with the following keys:
            - ``outdir``
            - ``meta`` (run parameters)
            - ``views`` (per-view records whose ``ejection_fractions`` maps
              ``"EF_<ef:.2f>"`` to a per-sample record)
            - ``total_generations`` (attempts, each counted once)
            - ``successful_generations``
            - ``success_rate``
            - ``success`` (True when at least one generation succeeded)

        Raises:
            RuntimeError: if the EchoFlow generator failed to initialize.
        """
        self._ensure_echoflow()
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # produces the same UTC "%Y%m%d_%H%M%S" stamp.
        stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
        root = Path(outdir) if outdir else (self.temp_dir / f"echoflow_generation_{stamp}")
        root.mkdir(parents=True, exist_ok=True)
        paths = self._ensure_dirs(root)
        run_meta = {
            "timestamp_utc": stamp,
            "device": self.device,
            "views": views,
            "ejection_fractions": ejection_fractions,
            "num_frames": num_frames,
            "timestep": timestep,
        }
        results: Dict[str, Any] = {
            "outdir": str(root),
            "meta": run_meta,
            "views": {},
            "success": False,  # set from the outcome after the loop
            "total_generations": 0,
            "successful_generations": 0,
        }
        for view in views:
            view_rec: Dict[str, Any] = {
                "view": view,
                "ejection_fractions": {},
                "features_saved": [],
                "metadata_saved": [],
            }
            results["views"][view] = view_rec
            for ef in ejection_fractions:
                key = f"EF_{ef:.2f}"
                # Count each attempt exactly once, before anything can fail, so an
                # exception raised while saving cannot double-count it or leave a
                # failed sample counted as successful.
                results["total_generations"] += 1
                try:
                    record = self._generate_one(
                        view, ef, num_frames, timestep,
                        save_features, save_metadata, paths, view_rec,
                    )
                except Exception as e:
                    view_rec["ejection_fractions"][key] = {
                        "success": False,
                        "error": str(e),
                    }
                    print(f"❌ {view} EF={ef:.2f} generation error: {e}")
                    continue
                view_rec["ejection_fractions"][key] = record
                if record["success"]:
                    results["successful_generations"] += 1

        total = results["total_generations"]
        results["success_rate"] = (
            results["successful_generations"] / total if total > 0 else 0.0
        )
        results["success"] = results["successful_generations"] > 0
        print("\n📊 Generation Summary:")
        print(f" Total generations: {results['total_generations']}")
        print(f" Successful: {results['successful_generations']}")
        print(f" Success rate: {results['success_rate']:.2%}")
        return results

    async def _arun(  # pragma: no cover
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Async entry point that delegates to the synchronous :meth:`_run`.

        NOTE(review): this blocks the event loop for the whole generation.
        Consider offloading to a worker thread once EchoFlowFinal is confirmed
        safe for concurrent use.
        """
        return self._run(
            views=views,
            ejection_fractions=ejection_fractions,
            num_frames=num_frames,
            timestep=timestep,
            outdir=outdir,
            save_features=save_features,
            save_metadata=save_metadata,
        )
# ----------------------------- Integration functions -----------------------------
def create_echoflow_tool(device: Optional[str] = None, temp_dir: Optional[str] = None) -> EchoFlowGenerationTool:
    """Factory for :class:`EchoFlowGenerationTool`.

    Args:
        device: Torch device to run on (e.g. ``"cuda"`` or ``"cpu"``). ``None``
            lets the tool pick CUDA when it is available.
        temp_dir: Directory under which default output folders are created.
            ``None`` makes the tool allocate a fresh temporary directory.

    Returns:
        A newly constructed tool. Construction attempts to load the EchoFlow
        models. If that fails, the tool is still returned, but running it raises.
    """
    tool_kwargs = {"device": device, "temp_dir": temp_dir}
    return EchoFlowGenerationTool(**tool_kwargs)
def test_echoflow_tool():
    """Smoke-test the EchoFlow tool end to end.

    Builds a tool with :func:`create_echoflow_tool`, then generates A4C and PLAX
    samples at EF 0.35 and 0.65 (8 frames each). Output goes to
    ``temp/echoflow_test_output`` under ``Path(__file__).resolve().parents[2]``.
    A short summary is printed.

    Returns:
        The tool's result dict, or ``None`` if any exception was raised (the
        traceback is printed).
    """
    print("🧪 Testing EchoFlow Tool")
    print("=" * 40)
    try:
        tool = create_echoflow_tool()
        # Write into the project-level temp/ folder rather than a throwaway tmpdir
        # so the generated files can be inspected after the run.
        output_dir = Path(__file__).resolve().parents[2] / "temp" / "echoflow_test_output"
        result = tool.run({
            "views": ["A4C", "PLAX"],
            "ejection_fractions": [0.35, 0.65],
            "num_frames": 8,
            "timestep": 0.5,
            "outdir": str(output_dir),
            "save_features": True,
            "save_metadata": True,
        })
        print("\n📊 Test Results:")
        print(f" Success: {result['success']}")
        print(f" Success rate: {result['success_rate']:.2%}")
        print(f" Output directory: {result['outdir']}")
        return result
    except Exception as e:
        print(f"❌ EchoFlow tool test failed: {e}")
        import traceback
        traceback.print_exc()
        return None
if __name__ == "__main__":
    # Executed directly (not imported): run the end-to-end smoke test.
    _ = test_echoflow_tool()