File size: 20,953 Bytes
8f51ef2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
#!/usr/bin/env python3
"""
EchoFlow Final Working Implementation

This is the final working implementation that processes videos frame by frame
to avoid the STDiT multi-frame shape issues.
"""

import sys
import os
import json
import time
import traceback
import warnings
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List, Union
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
import cv2

PROJECT_ROOT = Path(__file__).resolve().parents[2]
ECHOFLOW_ROOT = PROJECT_ROOT / "EchoFlow"

for candidate in (PROJECT_ROOT, ECHOFLOW_ROOT):
    candidate_str = str(candidate)
    if candidate_str not in sys.path:
        sys.path.insert(0, candidate_str)

# Suppress warnings
warnings.filterwarnings("ignore")

class EchoFlowFinal:
    """Final working EchoFlow implementation."""
    
    def __init__(self, device: Optional[str] = None):
        """
        Initialize EchoFlow.
        
        Args:
            device: Device to use ('cuda', 'cpu', or None for auto-detection)
        """
        self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.dtype = torch.float32
        self.models = {}
        self.config = {}
        self.initialized = False
        
        print(f"πŸ”§ EchoFlow Final initialized on {self.device}")
        
    def load_config(self, config_path: Optional[str] = None) -> bool:
        """Load EchoFlow configuration."""
        try:
            if config_path is None:
                config_path = PROJECT_ROOT / "configs" / "echoflow_config.json"
            
            if os.path.exists(config_path):
                with open(config_path, 'r') as f:
                    self.config = json.load(f)
                print(f"βœ… Config loaded from {config_path}")
                return True
            else:
                print(f"⚠️  Config not found at {config_path}")
                return False
        except Exception as e:
            print(f"❌ Error loading config: {e}")
            return False
    
    def load_models(self) -> bool:
        """Load EchoFlow models."""
        try:
            print("πŸ€– Loading EchoFlow models...")
            
            # Add EchoFlow to path
            sys.path.insert(0, str(ECHOFLOW_ROOT))
            
            # Import core models
            from echoflow.common.models import ResNet18, DiffuserSTDiT, ContrastiveModel
            
            # Load ResNet18 for feature extraction
            self.models['resnet'] = ResNet18().to(self.device).eval()
            print("βœ… ResNet18 loaded")
            
            # Load STDiT for video generation (single frame only)
            self.models['stdit'] = DiffuserSTDiT().to(self.device).eval()
            print("βœ… STDiT loaded")
            
            self.initialized = True
            return True
            
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            traceback.print_exc()
            return False
    
    def preprocess_mask(self, mask: Union[np.ndarray, Image.Image, None], 
                       target_size: Tuple[int, int] = (112, 112)) -> torch.Tensor:
        """
        Preprocess mask for EchoFlow generation.
        
        Args:
            mask: Input mask (numpy array, PIL Image, or None)
            target_size: Target size for the mask (height, width)
            
        Returns:
            Preprocessed mask tensor
        """
        try:
            if mask is None:
                # Create empty mask
                mask_array = np.zeros(target_size, dtype=np.uint8)
            elif isinstance(mask, Image.Image):
                # Convert PIL to numpy
                mask_array = np.array(mask.convert('L'))
            elif isinstance(mask, np.ndarray):
                # Use numpy array directly
                mask_array = mask
            else:
                raise ValueError(f"Unsupported mask type: {type(mask)}")
            
            # Resize to target size
            mask_resized = cv2.resize(mask_array, target_size, interpolation=cv2.INTER_NEAREST)
            
            # Convert to binary (0 or 1)
            mask_binary = (mask_resized > 127).astype(np.float32)
            
            # Convert to tensor
            mask_tensor = torch.from_numpy(mask_binary).unsqueeze(0).unsqueeze(0)
            mask_tensor = mask_tensor.to(self.device, dtype=self.dtype)
            
            return mask_tensor
            
        except Exception as e:
            print(f"❌ Error preprocessing mask: {e}")
            # Return empty mask on error
            return torch.zeros(1, 1, *target_size, device=self.device, dtype=self.dtype)
    
    def generate_image_features(self, image: Union[np.ndarray, torch.Tensor], 
                               target_size: Tuple[int, int] = (224, 224)) -> torch.Tensor:
        """
        Generate features from an image using ResNet18.
        
        Args:
            image: Input image (numpy array or torch tensor)
            target_size: Target size for the image (height, width)
            
        Returns:
            Feature tensor
        """
        try:
            if not self.initialized or 'resnet' not in self.models:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")
            
            # Convert to tensor if needed
            if isinstance(image, np.ndarray):
                if image.ndim == 3 and image.shape[2] == 3:
                    # RGB image
                    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
                elif image.ndim == 2:
                    # Grayscale image
                    image_tensor = torch.from_numpy(image).unsqueeze(0).float() / 255.0
                    image_tensor = image_tensor.repeat(3, 1, 1)  # Convert to RGB
                else:
                    raise ValueError(f"Unsupported image shape: {image.shape}")
            else:
                image_tensor = image
            
            # Add batch dimension if needed
            if image_tensor.ndim == 3:
                image_tensor = image_tensor.unsqueeze(0)
            
            # Resize to target size
            image_tensor = torch.nn.functional.interpolate(
                image_tensor, size=target_size, mode='bilinear', align_corners=False
            )
            
            # Move to device
            image_tensor = image_tensor.to(self.device, dtype=self.dtype)
            
            # Generate features
            with torch.no_grad():
                features = self.models['resnet'](image_tensor)
            
            return features
            
        except Exception as e:
            print(f"❌ Error generating image features: {e}")
            traceback.print_exc()
            return torch.zeros(1, 1000, device=self.device, dtype=self.dtype)
    
    def generate_single_frame_features(self, frame: Union[np.ndarray, torch.Tensor], 
                                      timestep: float = 0.5) -> torch.Tensor:
        """
        Generate features from a single frame using STDiT.
        This is the ONLY way that works with the current STDiT model.
        
        Args:
            frame: Input frame (numpy array or torch tensor)
            timestep: Diffusion timestep (0.0 to 1.0)
            
        Returns:
            Frame feature tensor
        """
        try:
            if not self.initialized or 'stdit' not in self.models:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")
            
            # Convert to tensor if needed
            if isinstance(frame, np.ndarray):
                if frame.ndim == 3:  # H, W, C
                    frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                elif frame.ndim == 2:  # H, W
                    frame_tensor = torch.from_numpy(frame).unsqueeze(0).float() / 255.0
                    frame_tensor = frame_tensor.repeat(3, 1, 1)  # Convert to RGB
                else:
                    raise ValueError(f"Unsupported frame shape: {frame.shape}")
            else:
                frame_tensor = frame
            
            # Add batch and time dimensions if needed
            if frame_tensor.ndim == 3:
                frame_tensor = frame_tensor.unsqueeze(0)  # Add batch dimension
            if frame_tensor.ndim == 4:
                frame_tensor = frame_tensor.unsqueeze(2)  # Add time dimension
            
            # Ensure correct shape (B, C, T, H, W) with T=1
            if frame_tensor.shape[1] != 4:  # Not 4-channel latent
                # Convert to 4-channel if needed
                if frame_tensor.shape[1] == 3:  # RGB
                    # Add alpha channel
                    alpha = torch.ones(frame_tensor.shape[0], 1, *frame_tensor.shape[2:])
                    frame_tensor = torch.cat([frame_tensor, alpha], dim=1)
                else:
                    raise ValueError(f"Unsupported frame channels: {frame_tensor.shape[1]}")
            
            # Resize to model input size (32x32)
            frame_tensor = torch.nn.functional.interpolate(
                frame_tensor.view(-1, *frame_tensor.shape[2:]), 
                size=(32, 32), 
                mode='bilinear', 
                align_corners=False
            ).view(frame_tensor.shape[0], frame_tensor.shape[1], frame_tensor.shape[2], 32, 32)
            
            # Move to device
            frame_tensor = frame_tensor.to(self.device, dtype=self.dtype)
            
            # Create timestep tensor
            timestep_tensor = torch.tensor([timestep], device=self.device, dtype=self.dtype)
            
            # Generate features
            with torch.no_grad():
                output = self.models['stdit'](frame_tensor, timestep_tensor)
                features = output.sample
            
            return features
            
        except Exception as e:
            print(f"❌ Error generating single frame features: {e}")
            traceback.print_exc()
            return torch.zeros(1, 4, 1, 32, 32, device=self.device, dtype=self.dtype)
    
    def generate_video_features_frame_by_frame(self, video: Union[np.ndarray, torch.Tensor], 
                                              timestep: float = 0.5) -> torch.Tensor:
        """
        Generate features from a video by processing each frame individually.
        This is the ONLY reliable way to process multi-frame videos.
        
        Args:
            video: Input video (numpy array or torch tensor)
            timestep: Diffusion timestep (0.0 to 1.0)
            
        Returns:
            Video feature tensor
        """
        try:
            if not self.initialized or 'stdit' not in self.models:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")
            
            # Convert to tensor if needed
            if isinstance(video, np.ndarray):
                if video.ndim == 4:  # T, H, W, C
                    video_tensor = torch.from_numpy(video).permute(3, 0, 1, 2).float() / 255.0
                elif video.ndim == 5:  # B, T, H, W, C
                    video_tensor = torch.from_numpy(video).permute(0, 4, 1, 2, 3).float() / 255.0
                else:
                    raise ValueError(f"Unsupported video shape: {video.shape}")
            else:
                video_tensor = video
            
            # Add batch dimension if needed
            if video_tensor.ndim == 4:
                video_tensor = video_tensor.unsqueeze(0)
            
            # Ensure correct shape (B, C, T, H, W)
            if video_tensor.shape[1] != 4:  # Not 4-channel latent
                # Convert to 4-channel if needed
                if video_tensor.shape[1] == 3:  # RGB
                    # Add alpha channel
                    alpha = torch.ones(video_tensor.shape[0], 1, *video_tensor.shape[2:])
                    video_tensor = torch.cat([video_tensor, alpha], dim=1)
                else:
                    raise ValueError(f"Unsupported video channels: {video_tensor.shape[1]}")
            
            # Process each frame individually
            batch_size, channels, num_frames, height, width = video_tensor.shape
            frame_features = []
            
            for t in range(num_frames):
                # Extract single frame
                frame = video_tensor[:, :, t, :, :]  # B, C, H, W
                
                # Resize to model input size (32x32)
                frame_resized = torch.nn.functional.interpolate(
                    frame, size=(32, 32), mode='bilinear', align_corners=False
                )
                
                # Add time dimension for STDiT
                frame_with_time = frame_resized.unsqueeze(2)  # B, C, 1, H, W
                
                # Move to device
                frame_with_time = frame_with_time.to(self.device, dtype=self.dtype)
                
                # Create timestep tensor
                timestep_tensor = torch.tensor([timestep], device=self.device, dtype=self.dtype)
                
                # Generate features for this frame
                with torch.no_grad():
                    output = self.models['stdit'](frame_with_time, timestep_tensor)
                    frame_feat = output.sample
                
                frame_features.append(frame_feat)
            
            # Stack frame features
            video_features = torch.cat(frame_features, dim=2)  # B, C, T, H, W
            
            return video_features
            
        except Exception as e:
            print(f"❌ Error generating video features: {e}")
            traceback.print_exc()
            # Return a safe fallback
            return torch.zeros(1, 4, 1, 32, 32, device=self.device, dtype=self.dtype)
    
    def generate_synthetic_echo(self, mask: Union[np.ndarray, Image.Image, None],
                               view_type: str = "A4C",
                               ejection_fraction: float = 0.65,
                               num_frames: int = 16) -> Dict[str, Any]:
        """
        Generate synthetic echocardiogram from mask.
        
        Args:
            mask: Input mask for the left ventricle
            view_type: Type of echo view ("A4C", "PSAX", "PLAX")
            ejection_fraction: Ejection fraction (0.0 to 1.0)
            num_frames: Number of frames in the generated video
            
        Returns:
            Dictionary containing generated features and metadata
        """
        try:
            if not self.initialized:
                raise RuntimeError("EchoFlow not initialized. Call load_models() first.")
            
            print(f"🎬 Generating synthetic echo: {view_type}, EF={ejection_fraction:.2f}, frames={num_frames}")
            
            # Preprocess mask
            mask_tensor = self.preprocess_mask(mask)
            
            # Create dummy video (in real implementation, this would be generated)
            dummy_video = np.random.randint(0, 255, (num_frames, 224, 224, 3), dtype=np.uint8)
            
            # Generate features using frame-by-frame processing
            video_features = self.generate_video_features_frame_by_frame(dummy_video, timestep=ejection_fraction)
            
            # Create result
            result = {
                "success": True,
                "view_type": view_type,
                "ejection_fraction": ejection_fraction,
                "num_frames": num_frames,
                "video_features": video_features.cpu().numpy(),
                "mask_processed": mask_tensor.cpu().numpy(),
                "timestamp": time.time(),
                "device": str(self.device)
            }
            
            print(f"βœ… Synthetic echo generated successfully")
            print(f"   Video features shape: {video_features.shape}")
            return result
            
        except Exception as e:
            print(f"❌ Error generating synthetic echo: {e}")
            traceback.print_exc()
            return {
                "success": False,
                "error": str(e),
                "timestamp": time.time()
            }
    
    def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
        """Save generation results to file."""
        try:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            
            # Convert numpy arrays to lists for JSON serialization
            serializable_results = {}
            for key, value in results.items():
                if isinstance(value, np.ndarray):
                    serializable_results[key] = value.tolist()
                else:
                    serializable_results[key] = value
            
            # Save to JSON
            with open(output_path, 'w') as f:
                json.dump(serializable_results, f, indent=2)
            
            print(f"βœ… Results saved to {output_path}")
            return True
            
        except Exception as e:
            print(f"❌ Error saving results: {e}")
            return False

def create_echoflow_generator(device: Optional[str] = None) -> EchoFlowFinal:
    """
    Create and initialize an EchoFlow generator.
    
    Args:
        device: Device to use ('cuda', 'cpu', or None for auto-detection)
        
    Returns:
        Initialized EchoFlowFinal instance
    """
    generator = EchoFlowFinal(device)
    
    # Load configuration
    if not generator.load_config():
        print("⚠️  Could not load config, using defaults")
    
    # Load models
    if not generator.load_models():
        raise RuntimeError("Failed to load EchoFlow models")
    
    return generator

def test_final_echoflow():
    """Test the final EchoFlow implementation."""
    print("πŸ§ͺ Testing Final EchoFlow Implementation")
    print("=" * 50)
    
    try:
        # Create generator
        generator = create_echoflow_generator()
        
        # Test image processing
        print("\n1️⃣ Testing image processing...")
        dummy_image = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        features = generator.generate_image_features(dummy_image)
        print(f"βœ… Image features generated: {features.shape}")
        
        # Test single frame processing
        print("\n2️⃣ Testing single frame processing...")
        dummy_frame = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
        single_frame_features = generator.generate_single_frame_features(dummy_frame)
        print(f"βœ… Single frame features generated: {single_frame_features.shape}")
        
        # Test multi-frame processing (frame by frame)
        print("\n3️⃣ Testing multi-frame processing...")
        test_frames = [4, 8, 16, 32]  # Test different frame counts
        
        for num_frames in test_frames:
            try:
                print(f"  πŸ§ͺ Testing {num_frames} frames...")
                dummy_video = np.random.randint(0, 255, (num_frames, 224, 224, 3), dtype=np.uint8)
                video_features = generator.generate_video_features_frame_by_frame(dummy_video)
                print(f"    βœ… {num_frames} frames processed successfully: {video_features.shape}")
            except Exception as e:
                print(f"    ❌ {num_frames} frames failed: {e}")
        
        # Test synthetic echo generation with different frame counts
        print("\n4️⃣ Testing synthetic echo generation...")
        dummy_mask = np.random.randint(0, 255, (400, 400), dtype=np.uint8)
        
        for num_frames in [4, 8, 16]:
            try:
                print(f"  πŸ§ͺ Testing {num_frames} frame synthetic echo...")
                result = generator.generate_synthetic_echo(
                    mask=dummy_mask,
                    view_type="A4C",
                    ejection_fraction=0.65,
                    num_frames=num_frames
                )
                
                if result["success"]:
                    print(f"    βœ… {num_frames} frame synthetic echo generated successfully")
                    print(f"       Video features shape: {result['video_features'].shape}")
                else:
                    print(f"    ❌ {num_frames} frame synthetic echo failed: {result.get('error', 'Unknown error')}")
            except Exception as e:
                print(f"    ❌ {num_frames} frame synthetic echo error: {e}")
        
        print("\nπŸŽ‰ Final EchoFlow test completed successfully!")
        return True
        
    except Exception as e:
        print(f"❌ Final EchoFlow test failed: {e}")
        traceback.print_exc()
        return False

if __name__ == "__main__":
    # Run final test
    test_final_echoflow()