import torch
import torch.nn as nn
import torch.distributed as dist
from typing import List, Optional, Dict, Any
import functools
import logging
import os
from contextlib import contextmanager


from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

try:
    # Pipeline parallelism is optional; some PyTorch builds ship without it.
    from torch.distributed.pipeline.sync import Pipe
except Exception:
    Pipe = None


from .model import BitTransformerLM, LoggingTransformerEncoderLayer
from .error_handling import with_error_recovery, safe_operation
from .types import DeviceType, WorldSize, ProcessRank


@with_error_recovery(max_retries=2)
def setup_distributed(rank: ProcessRank = 0,
                      world_size: WorldSize = 1,
                      backend: str = "nccl",
                      init_method: str = "tcp://localhost:23456") -> bool:
    """Initialize distributed training environment."""
    if world_size <= 1:
        return False

    try:
        dist.init_process_group(
            backend=backend,
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
        logging.info(f"Initialized distributed training: rank {rank}/{world_size}")
        return True
    except Exception as e:
        logging.error(f"Failed to initialize distributed training: {e}")
        return False


def wrap_fsdp(model: BitTransformerLM,
              sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD,
              **kwargs) -> FullyShardedDataParallel:
    """Return an optimized FSDP wrapped model with transformer-aware sharding."""
    device = kwargs.pop("device_id", None)
    if device is None and torch.cuda.is_available():
        device = torch.cuda.current_device()

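    # Defaults tuned for transformer workloads; callers can override any of
    # these through **kwargs.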
    fsdp_config = {
        "sharding_strategy": sharding_strategy,
        "cpu_offload": kwargs.pop("cpu_offload", None),
        "mixed_precision": kwargs.pop("mixed_precision", None),
        # transformer_auto_wrap_policy must be bound to the transformer layer
        # class to shard around (here the imported LoggingTransformerEncoderLayer);
        # the bare function is not a valid policy.
        "auto_wrap_policy": functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LoggingTransformerEncoderLayer},
        ),
        "backward_prefetch": kwargs.pop("backward_prefetch", None),
        "forward_prefetch": kwargs.pop("forward_prefetch", False),
        "limit_all_gathers": kwargs.pop("limit_all_gathers", True),
        "use_orig_params": kwargs.pop("use_orig_params", True),
        **kwargs,
    }

    # Drop unset options so FSDP falls back to its own defaults for them.
    fsdp_config = {k: v for k, v in fsdp_config.items() if v is not None}

    if device is not None:
        model = model.to(device)
        fsdp_config["device_id"] = device

    return FullyShardedDataParallel(model, **fsdp_config)


class OptimizedPipeline(nn.Module):
    """Enhanced pipeline parallelism with BitTransformerLM optimizations."""

    def __init__(self,
                 model: BitTransformerLM,
                 num_stages: int = 1,
                 chunks: int = 1,
                 checkpoint: bool = True):
        super().__init__()

        if Pipe is None:
            raise RuntimeError("Pipeline parallelism not available in this build")

        self.num_stages = num_stages
        self.chunks = chunks
        self.checkpoint = checkpoint

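        # Multi-stage: split the model across pipeline stages; single-stage:
        # wrap the whole model in one Pipe so micro-batching still applies.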
        if num_stages > 1:
            self.pipeline_model = self._create_pipeline_stages(model, num_stages)
        else:
            self.pipeline_model = Pipe(
                nn.Sequential(model),
                chunks=chunks,
                checkpoint="except_last" if checkpoint else "never",
            )

    def _create_pipeline_stages(self, model: BitTransformerLM, num_stages: int) -> Pipe:
        """Create optimized pipeline stages for BitTransformerLM."""
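        # Flatten the model into an ordered list of sub-modules for partitioning.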
        layers = []

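        # Input side: embedding and positional encoding, when present.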
        if hasattr(model, 'embedding'):
            layers.append(model.embedding)
        if hasattr(model, 'pos_encoding'):
            layers.append(model.pos_encoding)

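        # Transformer body: encoder layers may live on .layers or under .transformer.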
        if hasattr(model, 'layers'):
            layers.extend(model.layers)
        elif hasattr(model, 'transformer'):
            layers.extend(model.transformer.layers)

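        # Output side: final projection, when present.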
        if hasattr(model, 'output_projection'):
            layers.append(model.output_projection)

        # Split layers as evenly as possible; any remainder goes to the last stage.
        layers_per_stage = len(layers) // num_stages
        partitions = [layers_per_stage] * num_stages
        partitions[-1] += len(layers) % num_stages

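        # Group the flat layer list into nn.Sequential stages of the computed sizes.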
        stages = []
        start_idx = 0
        for partition_size in partitions:
            end_idx = start_idx + partition_size
            stage_layers = layers[start_idx:end_idx]
            stages.append(nn.Sequential(*stage_layers))
            start_idx = end_idx

        return Pipe(
            nn.Sequential(*stages),
            chunks=self.chunks,
            checkpoint="except_last" if self.checkpoint else "never",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the pipeline."""
        output = self.pipeline_model(x)
        # torch.distributed.pipeline.sync.Pipe returns an RRef to the output;
        # unwrap it so callers always receive a plain tensor.
        return output.local_value() if hasattr(output, "local_value") else output


def make_pipeline(model: BitTransformerLM,
                  chunks: int = 1,
                  num_stages: int = 1,
                  checkpoint: bool = True) -> OptimizedPipeline:
    """Create an optimized pipeline with advanced parallelism features."""
    return OptimizedPipeline(
        model=model,
        num_stages=num_stages,
        chunks=chunks,
        checkpoint=checkpoint,
    )


class DistributedTrainingManager:
    """Manages distributed training configuration and optimization."""

    def __init__(self,
                 world_size: WorldSize,
                 rank: ProcessRank,
                 use_pipeline: bool = False,
                 use_fsdp: bool = True):
        self.world_size = world_size
        self.rank = rank
        self.use_pipeline = use_pipeline
        self.use_fsdp = use_fsdp
        self.is_distributed = world_size > 1

        self.logger = logging.getLogger(__name__)

    def setup_model(self,
                    model: BitTransformerLM,
                    pipeline_stages: int = 1,
                    fsdp_config: Optional[Dict[str, Any]] = None) -> nn.Module:
        """Set up model for distributed training."""
        if not self.is_distributed:
            return model

        with safe_operation("distributed_model_setup"):
            if self.use_pipeline and pipeline_stages > 1:
                self.logger.info(f"Setting up pipeline parallelism with {pipeline_stages} stages")
                return make_pipeline(
                    model,
                    chunks=2,
                    num_stages=pipeline_stages,
                )

            elif self.use_fsdp:
                self.logger.info("Setting up FSDP for data parallelism")
                fsdp_config = fsdp_config or {}
                return wrap_fsdp(model, **fsdp_config)

            else:
                self.logger.info("Using standard DistributedDataParallel")
                return nn.parallel.DistributedDataParallel(model)

    def optimize_communication(self, model: nn.Module) -> None:
        """Apply communication optimizations for distributed training."""
        if not self.is_distributed:
            return

        if isinstance(model, nn.parallel.DistributedDataParallel):
            # DDP's gradient bucket size is fixed at construction time via the
            # bucket_cap_mb argument, so only a communication hook is attached here.
            try:
                from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
                # Compress gradients to fp16 during all-reduce to reduce bandwidth.
                model.register_comm_hook(
                    dist.group.WORLD,
                    default_hooks.fp16_compress_hook,
                )
            except (ImportError, AttributeError):
                self.logger.warning("DDP comm hooks unavailable; skipping gradient compression")

    @contextmanager
    def training_context(self):
        """Context manager for distributed training setup."""
        try:
            if self.is_distributed:
                self.logger.info("Entering distributed training context")

                if torch.cuda.is_available():
                    # Prefer the per-node LOCAL_RANK when the launcher provides one;
                    # the global rank is only a valid device index on a single node.
                    local_rank = int(os.environ.get("LOCAL_RANK", self.rank))
                    torch.cuda.set_device(local_rank)
            yield
        finally:
            if self.is_distributed:
                self.logger.info("Exiting distributed training context")


def cleanup_distributed():
    """Clean up distributed training environment."""
    if dist.is_initialized():
        dist.destroy_process_group()
        logging.info("Distributed training cleaned up")


def get_distributed_config() -> Dict[str, Any]:
    """Get current distributed training configuration."""
    if not dist.is_initialized():
        return {"distributed": False}

    return {
        "distributed": True,
        "world_size": dist.get_world_size(),
        "rank": dist.get_rank(),
        "backend": dist.get_backend(),
        "local_rank": int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else None,
    }


def all_reduce_tensor(tensor: torch.Tensor,
                      op: dist.ReduceOp = dist.ReduceOp.SUM) -> torch.Tensor:
    """All-reduce operation on tensor across all processes."""
    if not dist.is_initialized():
        return tensor

    dist.all_reduce(tensor, op=op)
    return tensor


def gather_tensors(tensor: torch.Tensor,
                   dst: int = 0) -> Optional[List[torch.Tensor]]:
    """Gather tensors from all processes to destination rank."""
    if not dist.is_initialized():
        return [tensor]

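    # Only the destination rank allocates receive buffers; other ranks just send.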
    if dist.get_rank() == dst:
        tensor_list = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.gather(tensor, tensor_list, dst=dst)
        return tensor_list
    else:
        dist.gather(tensor, dst=dst)
        return None


def broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Broadcast tensor from source rank to all processes."""
    if not dist.is_initialized():
        return tensor

    dist.broadcast(tensor, src=src)
    return tensor


class PipelineScheduler:
    """Advanced scheduler for pipeline parallelism with load balancing."""

    def __init__(self, num_stages: int, world_size: int):
        self.num_stages = num_stages
        self.world_size = world_size
        self.stage_times = [0.0] * num_stages
        self.load_balance_enabled = True

    def update_stage_timing(self, stage_id: int, execution_time: float):
        """Update execution time for a pipeline stage."""
        if 0 <= stage_id < self.num_stages:
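            # Exponential moving average smooths noisy per-step timings.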
            alpha = 0.1
            self.stage_times[stage_id] = (1 - alpha) * self.stage_times[stage_id] + alpha * execution_time

    def get_optimal_chunks(self, batch_size: int) -> int:
        """Calculate optimal number of chunks based on stage timing."""
        if not self.load_balance_enabled:
            return max(1, batch_size // 8)

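        # More chunks help hide the pipeline bubble created by the slowest stage.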
        max_stage_time = max(self.stage_times) if any(self.stage_times) else 1.0
        avg_stage_time = sum(self.stage_times) / len(self.stage_times) if self.stage_times else 1.0

        imbalance_factor = max_stage_time / max(avg_stage_time, 1e-6)
        optimal_chunks = max(2, min(batch_size, int(4 * imbalance_factor)))

        return optimal_chunks


def efficient_gradient_sync(model: nn.Module, gradient_clipping: float = 1.0):
    """Perform memory-efficient gradient synchronization across processes."""
    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    # Average gradients across ranks first so every rank holds identical values.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size

    # Clip after synchronization so all ranks apply the same clipping factor.
    if gradient_clipping > 0:
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

        if dist.get_rank() == 0:
            logging.debug(f"Gradient norm before clipping: {total_norm.item():.4f}")


class DistributedMemoryManager:
    """Manages memory efficiently across distributed processes."""

    def __init__(self, enable_cpu_offload: bool = False):
        self.enable_cpu_offload = enable_cpu_offload
        self.memory_stats = {}
        self.peak_memory = 0

    def monitor_memory(self):
        """Monitor GPU memory usage across processes."""
        if torch.cuda.is_available():
            current_memory = torch.cuda.memory_allocated()
            max_memory = torch.cuda.max_memory_allocated()

            self.memory_stats = {
                "current_gb": current_memory / 1e9,
                "peak_gb": max_memory / 1e9,
                "rank": dist.get_rank() if dist.is_initialized() else 0,
            }

            self.peak_memory = max(self.peak_memory, current_memory)

    def optimize_memory_usage(self):
        """Apply memory optimizations based on current usage."""
        if torch.cuda.is_available():
            # Clear the allocator cache when reserved memory nears device capacity;
            # comparing against the process's own historical peak triggers too eagerly.
            total_memory = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
            if torch.cuda.memory_reserved() > 0.8 * total_memory:
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache due to high memory usage")

    def get_memory_report(self) -> Dict[str, float]:
        """Get comprehensive memory usage report."""
        self.monitor_memory()
        return self.memory_stats


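# Module-level defaults; setup_advanced_distributed_training() replaces them with
# instances configured for the actual world size.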
pipeline_scheduler = PipelineScheduler(num_stages=1, world_size=1)
memory_manager = DistributedMemoryManager()


def setup_advanced_distributed_training(
    rank: ProcessRank,
    world_size: WorldSize,
    enable_memory_monitoring: bool = True,
    enable_pipeline_scheduling: bool = True,
) -> Dict[str, Any]:
    """Set up advanced distributed training with optimizations."""
    global pipeline_scheduler, memory_manager

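    # Initialize the process group first; single-process runs return early.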
    success = setup_distributed(rank, world_size)
    if not success:
        return {"distributed": False}

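    # Rebuild the module-level helpers with the real world size.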
    if enable_pipeline_scheduling:
        pipeline_scheduler = PipelineScheduler(num_stages=world_size, world_size=world_size)

    if enable_memory_monitoring:
        memory_manager = DistributedMemoryManager()
        memory_manager.monitor_memory()

    config = get_distributed_config()
    config.update({
        "pipeline_scheduling": enable_pipeline_scheduling,
        "memory_monitoring": enable_memory_monitoring,
        "advanced_features": True,
    })

    logging.info(f"Advanced distributed training initialized on rank {rank}")
    return config