|
|
|
|
|
""" |
|
|
EchoFlow Integrated Tool |
|
|
|
|
|
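This module wraps the EchoFlow synthetic-echo generator as a LangChain tool so the main echo analysis system can request synthetic echocardiography data for chosen views and ejection fractions.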
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from typing import Any, Dict, List, Optional, Type |
|
|
from pathlib import Path |
|
|
import tempfile |
|
|
import shutil |
|
|
import datetime |
|
|
import os |
|
|
import sys |
|
|
|
|
|
import numpy as np |
|
|
import cv2 |
|
|
import torch |
|
|
from pydantic import BaseModel, Field, field_validator |
|
|
from langchain_core.tools import BaseTool |
|
|
from langchain_core.callbacks import ( |
|
|
CallbackManagerForToolRun, |
|
|
AsyncCallbackManagerForToolRun, |
|
|
) |
|
|
|
|
|
|
|
|
from .echoflow_final_working import EchoFlowFinal |
|
|
|
|
|
|
|
|
|
|
|
class EchoFlowGenerationInput(BaseModel):
    """Generate synthetic echo images and videos using EchoFlow.

    Argument schema for the ``echoflow_generation`` tool. One generation is
    attempted for every combination of ``views`` and ``ejection_fractions``.
    """

    views: List[str] = Field(
        default_factory=lambda: ["A4C", "PLAX", "PSAX"],
        description="Cardiac echo views to synthesize (e.g., A4C, PLAX, PSAX).",
    )
    ejection_fractions: List[float] = Field(
        default_factory=lambda: [0.35, 0.55, 0.70],
        description="Ejection fraction values (0.0 to 1.0) used to condition the generation.",
    )
    num_frames: int = Field(16, ge=1, le=64, description="Number of frames in generated videos.")
    timestep: float = Field(0.5, ge=0.0, le=1.0, description="Diffusion timestep for generation.")
    outdir: Optional[str] = Field(
        None,
        description="Root output dir. If omitted, a timestamped folder is created under the tool temp dir.",
    )
    save_features: bool = Field(True, description="Save generated features as numpy arrays.")
    save_metadata: bool = Field(True, description="Save generation metadata.")

    @field_validator("views")
    @classmethod
    def _nonempty_views(cls, v: List[str]) -> List[str]:
        """Require at least one view and reject blank view names.

        View names are embedded in output file names, so a blank name would
        produce files such as ``_EF0.35_features.npy``.
        """
        if not v:
            raise ValueError("At least one view must be provided.")
        for view in v:
            if not view.strip():
                raise ValueError("View names must be non-empty strings.")
        return v

    @field_validator("ejection_fractions")
    @classmethod
    def _valid_efs(cls, v: List[float]) -> List[float]:
        """Require at least one ejection fraction, each within [0.0, 1.0]."""
        if not v:
            raise ValueError("At least one ejection fraction must be provided.")
        for x in v:
            # Negated range check: NaN fails every comparison, so it is
            # rejected here instead of slipping through `x < 0 or x > 1`.
            if not 0.0 <= x <= 1.0:
                raise ValueError(f"Ejection fraction {x} out of range [0.0, 1.0].")
        return v
|
|
|
|
|
|
|
|
|
|
|
class EchoFlowGenerationTool(BaseTool):
    """EchoFlow generation tool integrated with the main echo analysis system.

    Wraps an ``EchoFlowFinal`` generator as a LangChain tool. For every
    requested (view, ejection fraction) pair it calls
    ``EchoFlowFinal.generate_synthetic_echo`` and saves the returned feature
    and mask arrays (plus optional JSON metadata) under an output directory.

    If the generator cannot be initialised, construction still succeeds, but
    ``echoflow_generator`` is ``None`` and every run raises ``RuntimeError``.
    """

    name: str = "echoflow_generation"
    description: str = (
        "Generate synthetic echocardiography images and videos using EchoFlow. "
        "Creates realistic echo data for training, testing, and augmentation purposes. "
        "Supports multiple views (A4C, PLAX, PSAX) and ejection fraction conditioning."
    )
    args_schema: Type[BaseModel] = EchoFlowGenerationInput

    # Torch device string handed to EchoFlowFinal; resolved in __init__.
    device: Optional[str] = "cuda"
    # Root for timestamped run folders when a run does not specify `outdir`.
    temp_dir: Path = Path("temp")
    # Loaded generator, or None when initialisation failed.
    echoflow_generator: Optional[EchoFlowFinal] = None

    def __init__(
        self,
        device: Optional[str] = None,
        temp_dir: Optional[str] = None,
        **kwargs: Any,
    ):
        """Create the tool and try to load the EchoFlow generator.

        Args:
            device: Torch device ("cuda"/"cpu"). ``None`` picks CUDA when
                available, otherwise CPU.
            temp_dir: Directory for default outputs. ``None`` creates a fresh
                directory via ``tempfile.mkdtemp()``; it is not cleaned up.
            **kwargs: Forwarded unchanged to ``BaseTool``.

        Loading errors are printed and leave ``echoflow_generator`` as
        ``None`` rather than raising, so the tool can still be registered.
        """
        super().__init__(**kwargs)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.temp_dir = Path(temp_dir or tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        try:
            self.echoflow_generator = EchoFlowFinal(device=self.device)
            if not self.echoflow_generator.load_config():
                print("⚠️ Could not load EchoFlow config, using defaults")
            if not self.echoflow_generator.load_models():
                raise RuntimeError("Failed to load EchoFlow models")
            print("✅ EchoFlow generator initialized successfully")
        except Exception as e:
            # Best-effort initialisation: report and disable instead of crashing.
            print(f"❌ Failed to initialize EchoFlow generator: {e}")
            self.echoflow_generator = None

    def _ensure_echoflow(self):
        """Raise ``RuntimeError`` if the generator failed to initialise."""
        if self.echoflow_generator is None:
            raise RuntimeError(
                "EchoFlow generator not initialized. Check model loading and dependencies."
            )

    @staticmethod
    def _ensure_dirs(root: Path) -> Dict[str, Path]:
        """Create the output sub-directories under *root* and return them by name."""
        d = {
            "features": root / "features",
            "metadata": root / "metadata",
            "masks": root / "masks",
            "videos": root / "videos",
        }
        for p in d.values():
            p.mkdir(parents=True, exist_ok=True)
        return d

    @staticmethod
    def _save_numpy(path: Path, arr: Any) -> str:
        """Save *arr* as a ``.npy`` file and return the path as a string.

        Torch tensors are detached and moved to CPU first, because ``np.save``
        cannot convert CUDA tensors or tensors that require grad.
        """
        if torch.is_tensor(arr):
            arr = arr.detach().cpu().numpy()
        np.save(str(path), arr)
        return str(path)

    @staticmethod
    def _save_json(path: Path, data: Dict[str, Any]) -> str:
        """Write *data* as indented UTF-8 JSON and return the path as a string.

        Values JSON cannot encode natively are written via ``str()``.
        """
        import json

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)
        return str(path)

    @staticmethod
    def _unique_efs(ejection_fractions: List[float]) -> List[float]:
        """Drop EF values whose 2-decimal label repeats, keeping first occurrences.

        Result keys and file names use ``f"{ef:.2f}"``, so two EFs with the same
        label would otherwise overwrite each other's outputs.
        """
        by_label: Dict[str, float] = {}
        for ef in ejection_fractions:
            by_label.setdefault(f"{ef:.2f}", ef)
        return list(by_label.values())

    def _generate_one(
        self,
        view: str,
        ef: float,
        num_frames: int,
        timestep: float,
        paths: Dict[str, Path],
        save_features: bool,
        save_metadata: bool,
    ) -> Dict[str, Any]:
        """Run one generation, save its outputs and return its result entry.

        Returns ``{"success": False, "error": ...}`` when the generator reports
        failure. Exceptions from the generator or from file I/O propagate to
        the caller.
        """
        label = f"{view}_EF{ef:.2f}"

        # NOTE(review): placeholder conditioning input: uniform random noise in
        # [0, 254] (randint's upper bound is exclusive), not a real segmentation
        # mask. Replace once real masks are available.
        dummy_mask = np.random.randint(0, 255, (400, 400), dtype=np.uint8)

        # NOTE(review): `timestep` is only recorded in metadata; it is not passed
        # to the generator. Confirm whether generate_synthetic_echo should get it.
        result = self.echoflow_generator.generate_synthetic_echo(
            mask=dummy_mask,
            view_type=view,
            ejection_fraction=ef,
            num_frames=num_frames,
        )

        if not result.get("success"):
            return {"success": False, "error": result.get("error", "Unknown error")}

        features_path: Optional[Path] = None
        metadata_path: Optional[Path] = None

        if save_features:
            features_path = paths["features"] / f"{label}_features.npy"
            self._save_numpy(features_path, result["video_features"])

        if save_metadata:
            metadata = {
                "view": view,
                "ejection_fraction": ef,
                "num_frames": num_frames,
                "timestep": timestep,
                "video_features_shape": result["video_features"].shape,
                "mask_processed_shape": result["mask_processed"].shape,
                "timestamp": result["timestamp"],
                "device": result["device"],
            }
            metadata_path = paths["metadata"] / f"{label}_metadata.json"
            self._save_json(metadata_path, metadata)

        mask_path = paths["masks"] / f"{label}_mask.npy"
        self._save_numpy(mask_path, result["mask_processed"])

        return {
            "success": True,
            "video_features_shape": result["video_features"].shape,
            "features_path": str(features_path) if features_path is not None else None,
            "metadata_path": str(metadata_path) if metadata_path is not None else None,
            "mask_path": str(mask_path),
        }

    def _run(
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Generate every (view, ejection fraction) combination and save outputs.

        Args:
            views: Echo views to generate. Duplicates are generated once.
            ejection_fractions: EF values in [0, 1]. Values that share a
                2-decimal label are generated once (first occurrence wins).
            num_frames: Frames per generated video, passed to the generator.
            timestep: Recorded in run and per-item metadata.
            outdir: Output root. Defaults to a UTC-timestamped folder under
                ``temp_dir``.
            save_features: Save ``video_features`` as ``.npy`` files.
            save_metadata: Save per-item JSON metadata.
            run_manager: Unused; accepted for the LangChain interface.

        Returns:
            Dict with ``outdir``, ``meta``, ``views`` (per-view records),
            ``success`` (True when at least one generation succeeded),
            ``total_generations``, ``successful_generations`` and
            ``success_rate``.

        Raises:
            RuntimeError: If the EchoFlow generator was not initialised.
        """
        self._ensure_echoflow()

        stamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
        root = Path(outdir) if outdir else (self.temp_dir / f"echoflow_generation_{stamp}")
        root.mkdir(parents=True, exist_ok=True)
        paths = self._ensure_dirs(root)

        # Duplicates would overwrite each other's files/records while being
        # counted twice, so each view / EF label is generated only once.
        unique_views = list(dict.fromkeys(views))
        unique_efs = self._unique_efs(ejection_fractions)

        run_meta = {
            "timestamp_utc": stamp,
            "device": self.device,
            "views": unique_views,
            "ejection_fractions": unique_efs,
            "num_frames": num_frames,
            "timestep": timestep,
        }

        results: Dict[str, Any] = {
            "outdir": str(root),
            "meta": run_meta,
            "views": {},
            "success": False,  # finalised after the loop
            "total_generations": 0,
            "successful_generations": 0,
        }

        for view in unique_views:
            view_rec: Dict[str, Any] = {
                "view": view,
                "ejection_fractions": {},
                "features_saved": [],
                "metadata_saved": [],
            }
            results["views"][view] = view_rec

            for ef in unique_efs:
                key = f"EF_{ef:.2f}"
                results["total_generations"] += 1
                try:
                    print(f"🔬 Generating {view} view with EF={ef:.2f}")
                    entry = self._generate_one(
                        view, ef, num_frames, timestep, paths, save_features, save_metadata
                    )
                except Exception as e:
                    # One failing combination must not abort the whole batch.
                    view_rec["ejection_fractions"][key] = {"success": False, "error": str(e)}
                    print(f"❌ {view} EF={ef:.2f} generation error: {e}")
                    continue

                view_rec["ejection_fractions"][key] = entry
                if not entry["success"]:
                    print(f"❌ {view} EF={ef:.2f} generation failed: {entry['error']}")
                    continue

                # Counted only after every output for this item was written.
                results["successful_generations"] += 1
                if entry["features_path"] is not None:
                    view_rec["features_saved"].append(entry["features_path"])
                if entry["metadata_path"] is not None:
                    view_rec["metadata_saved"].append(entry["metadata_path"])
                print(f"✅ {view} EF={ef:.2f} generated successfully")

        total = results["total_generations"]
        succeeded = results["successful_generations"]
        results["success_rate"] = succeeded / total if total else 0.0
        results["success"] = succeeded > 0

        print("\n📊 Generation Summary:")
        print(f" Total generations: {total}")
        print(f" Successful: {succeeded}")
        print(f" Success rate: {results['success_rate']:.2%}")

        return results

    async def _arun(
        self,
        views: List[str],
        ejection_fractions: List[float],
        num_frames: int = 16,
        timestep: float = 0.5,
        outdir: Optional[str] = None,
        save_features: bool = True,
        save_metadata: bool = True,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Dict[str, Any]:
        """Async variant of ``_run``; same arguments and return value.

        Generation is blocking model inference plus file I/O, so it runs in the
        event loop's default executor instead of stalling the loop.
        NOTE(review): concurrent async calls can now overlap in worker threads;
        confirm EchoFlowFinal tolerates concurrent use.
        """
        import asyncio
        import functools

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(
                self._run,
                views=views,
                ejection_fractions=ejection_fractions,
                num_frames=num_frames,
                timestep=timestep,
                outdir=outdir,
                save_features=save_features,
                save_metadata=save_metadata,
            ),
        )
|
|
|
|
|
|
|
|
|
|
|
def create_echoflow_tool(device: Optional[str] = None, temp_dir: Optional[str] = None) -> EchoFlowGenerationTool:
    """Build a ready-to-use ``EchoFlowGenerationTool``.

    Args:
        device: Torch device name ("cuda" or "cpu"); ``None`` lets the tool
            auto-detect CUDA availability.
        temp_dir: Directory used for outputs when a run gives no ``outdir``;
            ``None`` lets the tool create a fresh temporary directory.

    Returns:
        The constructed tool. Generator-loading failures do not raise here;
        they leave the tool's ``echoflow_generator`` set to ``None``.
    """
    tool = EchoFlowGenerationTool(temp_dir=temp_dir, device=device)
    return tool
|
|
|
|
|
def test_echoflow_tool():
    """Smoke-test the EchoFlow tool end to end.

    Builds a tool with an auto-detected device and generates A4C and PLAX views
    at two ejection fractions. It then prints a summary and returns the tool's
    result dict. Any exception is caught; the traceback is printed and
    ``None`` is returned.
    """
    print("🧪 Testing EchoFlow Tool")
    print("=" * 40)

    try:
        tool = create_echoflow_tool()

        # Outputs go to <grandparent of this file's directory>/temp/echoflow_test_output.
        output_dir = Path(__file__).resolve().parents[2] / "temp" / "echoflow_test_output"
        result = tool.run({
            "views": ["A4C", "PLAX"],
            "ejection_fractions": [0.35, 0.65],
            "num_frames": 8,
            "timestep": 0.5,
            "outdir": str(output_dir),
            "save_features": True,
            "save_metadata": True,
        })

        print("\n📊 Test Results:")
        print(f" Success: {result['success']}")
        print(f" Success rate: {result['success_rate']:.2%}")
        print(f" Output directory: {result['outdir']}")

        return result

    except Exception as e:
        print(f"❌ EchoFlow tool test failed: {e}")
        import traceback

        traceback.print_exc()
        return None
|
|
|
|
|
if __name__ == "__main__":
    # test_echoflow_tool() catches its own errors and returns None on failure;
    # translate that into a non-zero exit status so scripts/CI can detect it.
    sys.exit(0 if test_echoflow_tool() is not None else 1)
|
|
|