| | import torch |
| | import torch.nn as nn |
| | from typing import Any, Dict, List, Tuple |
| |
|
def inspect_model_devices(model, prefix="", max_depth=3, current_depth=0):
    """
    Recursively inspect all public attributes of a PyTorch Lightning model to
    find which tensors live on CPU vs CUDA devices and check density/sparsity.

    Args:
        model: The model/object to inspect (anything `dir()` works on).
        prefix: Dotted-path prefix for nested attributes.
        max_depth: Maximum recursion depth to prevent infinite loops.
        current_depth: Current recursion depth (internal; callers leave at 0).

    Returns:
        Dict of category buckets ('cuda_tensors', 'cpu_tensors',
        'mixed_tensors', 'sparse_tensors', 'non_contiguous_tensors',
        'problematic_tensors', 'non_tensor_attrs', 'errors'). A tensor may
        appear in several buckets (e.g. both 'cpu_tensors' and
        'problematic_tensors').
    """
    results = {
        'cuda_tensors': [],
        'cpu_tensors': [],
        'mixed_tensors': [],
        'sparse_tensors': [],
        'non_contiguous_tensors': [],
        'problematic_tensors': [],
        'non_tensor_attrs': [],
        'errors': []
    }

    if current_depth >= max_depth:
        return results

    for attr_name in dir(model):
        # Skip private/dunder attributes.
        if attr_name.startswith('_'):
            continue

        # Build the dotted name *before* getattr so the except handler can
        # still report it when attribute access itself raises. (Previously
        # this assignment lived inside the try after getattr, so a raising
        # property caused a NameError on `full_name` in the handler.)
        full_name = f"{prefix}.{attr_name}" if prefix else attr_name

        try:
            attr_value = getattr(model, attr_name)

            if isinstance(attr_value, torch.Tensor):
                # NOTE: nn.Parameter is a torch.Tensor subclass, so this one
                # branch also covers parameters; the previous duplicated
                # `isinstance(attr_value, nn.Parameter)` branch was
                # unreachable dead code and has been removed.
                info = _tensor_device_info(full_name, attr_value)
                _record_tensor(info, results)

            elif isinstance(attr_value, nn.Module):
                # Summarize the device(s) of the module's parameters/buffers.
                module_devices = set()
                for param in attr_value.parameters():
                    module_devices.add(param.device.type)
                for buffer in attr_value.buffers():
                    module_devices.add(buffer.device.type)

                if len(module_devices) > 1:
                    results['mixed_tensors'].append({
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'devices': list(module_devices)
                    })
                elif len(module_devices) == 1:
                    device_type = next(iter(module_devices))
                    module_info = {
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'device': device_type
                    }
                    if device_type == 'cuda':
                        results['cuda_tensors'].append(module_info)
                    else:
                        results['cpu_tensors'].append(module_info)

                # Recurse into submodules for per-attribute detail.
                if current_depth < max_depth - 1:
                    sub_results = inspect_model_devices(
                        attr_value, full_name, max_depth, current_depth + 1)
                    for key in results:
                        results[key].extend(sub_results[key])

            elif hasattr(attr_value, '__dict__') and not callable(attr_value):
                # Generic container object: recurse one level deeper.
                if current_depth < max_depth - 1:
                    sub_results = inspect_model_devices(
                        attr_value, full_name, max_depth, current_depth + 1)
                    for key in results:
                        results[key].extend(sub_results[key])
            else:
                if not callable(attr_value):
                    results['non_tensor_attrs'].append({
                        'name': full_name,
                        'type': type(attr_value).__name__,
                        'value': str(attr_value)[:100]  # truncate long reprs
                    })

        except Exception as e:
            # Property getters on Lightning models can raise (e.g. accessing
            # trainer state before fit); record the failure and keep going.
            results['errors'].append({
                'name': full_name,
                'error': str(e)
            })
            continue

    return results


def _tensor_device_info(name, tensor):
    """Build the metadata dict describing one tensor's device/layout/density."""
    is_sparse = tensor.is_sparse or tensor.is_sparse_csr
    is_contiguous = tensor.is_contiguous()
    return {
        'name': name,
        'shape': tuple(tensor.shape),
        'dtype': str(tensor.dtype),
        'device': str(tensor.device),
        'requires_grad': tensor.requires_grad,
        'is_sparse': is_sparse,
        'is_contiguous': is_contiguous,
        # Stride/offset are undefined for sparse layouts.
        'stride': tuple(tensor.stride()) if not is_sparse else 'N/A (sparse)',
        'storage_offset': tensor.storage_offset() if not is_sparse else 'N/A (sparse)',
        'numel': tensor.numel()
    }


def _record_tensor(info, results):
    """File *info* into the category buckets and tag any problems found."""
    issues = []
    is_cuda = str(info['device']).startswith('cuda')
    if not is_cuda:
        issues.append('CPU')
    if info['is_sparse']:
        issues.append('SPARSE')
        results['sparse_tensors'].append(info)
    if not info['is_contiguous']:
        issues.append('NON_CONTIGUOUS')
        results['non_contiguous_tensors'].append(info)

    if issues:
        info['issues'] = issues
        results['problematic_tensors'].append(info)

    if is_cuda:
        results['cuda_tensors'].append(info)
    else:
        results['cpu_tensors'].append(info)
| |
|
def print_device_report(model, detailed: bool = False) -> None:
    """
    Print a formatted report of device allocation for all model components.

    Runs ``inspect_model_devices(model)`` (default depth) and prints each
    result bucket as its own section, ending with a count summary. Output
    goes to stdout only; nothing is returned.

    Args:
        model: PyTorch Lightning model to inspect
        detailed: If True, show detailed information for each tensor
    """
    print("="*80)
    print("COMPREHENSIVE DEVICE & TENSOR DENSITY REPORT")
    print("="*80)

    results = inspect_model_devices(model)

    # Problematic tensors first — these carry an 'issues' list of
    # CPU / SPARSE / NON_CONTIGUOUS flags set by the inspection pass.
    if results['problematic_tensors']:
        print(f"\n🚨 PROBLEMATIC TENSORS ({len(results['problematic_tensors'])}) - LIKELY CAUSING ISSUES!")
        print("-" * 70)
        for item in results['problematic_tensors']:
            issues_str = " | ".join(item['issues'])
            print(f" ❌ {item['name']}: {issues_str}")
            if detailed:
                print(f" Shape: {item['shape']} | Device: {item['device']}")
                print(f" Contiguous: {item['is_contiguous']} | Sparse: {item['is_sparse']}")
                # Stride/offset are only reported for dense tensors.
                if item['stride'] != 'N/A (sparse)':
                    print(f" Stride: {item['stride']} | Storage offset: {item['storage_offset']}")
                print()

    print(f"\n📍 CUDA TENSORS/MODULES ({len(results['cuda_tensors'])})")
    print("-" * 40)
    for item in results['cuda_tensors']:
        # Entries may be tensor dicts or module summaries; .get() guards the
        # keys (is_contiguous/is_sparse/shape) that module summaries lack.
        status_indicators = []
        if not item.get('is_contiguous', True):
            status_indicators.append('NON-CONTIGUOUS')
        if item.get('is_sparse', False):
            status_indicators.append('SPARSE')

        status_str = f" [{', '.join(status_indicators)}]" if status_indicators else ""

        if detailed:
            if 'shape' in item:
                # Tensor entry: full shape/dtype/device line.
                print(f" {item['name']}: {item['shape']} | {item['dtype']} | {item['device']}{status_str}")
                if 'stride' in item and item['stride'] != 'N/A (sparse)':
                    print(f" ↳ Contiguous: {item.get('is_contiguous', 'N/A')} | Stride: {item['stride']}")
            else:
                # Module summary entry: only type and device are known.
                print(f" {item['name']}: {item['type']} | {item['device']}{status_str}")
        else:
            print(f" ✅ {item['name']}{status_str}")

    print(f"\n💻 CPU TENSORS/MODULES ({len(results['cpu_tensors'])})")
    print("-" * 40)
    for item in results['cpu_tensors']:
        # Same tensor-vs-module handling as the CUDA section above.
        status_indicators = []
        if not item.get('is_contiguous', True):
            status_indicators.append('NON-CONTIGUOUS')
        if item.get('is_sparse', False):
            status_indicators.append('SPARSE')

        status_str = f" [{', '.join(status_indicators)}]" if status_indicators else ""

        if detailed:
            if 'shape' in item:
                print(f" {item['name']}: {item['shape']} | {item['dtype']} | {item['device']}{status_str}")
                if 'stride' in item and item['stride'] != 'N/A (sparse)':
                    print(f" ↳ Contiguous: {item.get('is_contiguous', 'N/A')} | Stride: {item['stride']}")
            else:
                print(f" {item['name']}: {item['type']} | {item['device']}{status_str}")
        else:
            # CPU placement is marked ❌ because this tool hunts for tensors
            # that should be on CUDA.
            print(f" ❌ {item['name']}{status_str}")

    if results['sparse_tensors']:
        print(f"\n🕳️ SPARSE TENSORS ({len(results['sparse_tensors'])})")
        print("-" * 40)
        for item in results['sparse_tensors']:
            print(f" {item['name']}: {item['shape']} | {item['device']} | SPARSE")

    if results['non_contiguous_tensors']:
        print(f"\n📐 NON-CONTIGUOUS TENSORS ({len(results['non_contiguous_tensors'])})")
        print("-" * 40)
        for item in results['non_contiguous_tensors']:
            print(f" {item['name']}: {item['shape']} | Stride: {item['stride']}")
            if detailed:
                print(f" ↳ Storage offset: {item['storage_offset']}")

    # Modules whose parameters/buffers span more than one device type.
    if results['mixed_tensors']:
        print(f"\n⚠️ MIXED DEVICE MODULES ({len(results['mixed_tensors'])})")
        print("-" * 40)
        for item in results['mixed_tensors']:
            print(f" {item['name']}: {item['type']} | Devices: {item['devices']}")

    # Attributes whose getattr raised during inspection.
    if results['errors']:
        print(f"\n❗ ERRORS ({len(results['errors'])})")
        print("-" * 40)
        for item in results['errors']:
            print(f" {item['name']}: {item['error']}")

    print(f"\n📊 DETAILED SUMMARY")
    print("-" * 40)
    print(f" CUDA components: {len(results['cuda_tensors'])}")
    print(f" CPU components: {len(results['cpu_tensors'])}")
    print(f" Sparse tensors: {len(results['sparse_tensors'])}")
    print(f" Non-contiguous: {len(results['non_contiguous_tensors'])}")
    print(f" Mixed components: {len(results['mixed_tensors'])}")
    print(f" Errors: {len(results['errors'])}")
    print(f" 🚨 TOTAL PROBLEMATIC: {len(results['problematic_tensors'])}")

    if results['problematic_tensors']:
        print(f"\n⚠️ CRITICAL: Found {len(results['problematic_tensors'])} problematic tensors!")
        print(" These are likely causing the 'Tensors must be CUDA and dense' error.")
        print(" Focus on fixing the tensors marked with 🚨 above.")