| import math |
| import multiprocessing |
| import os |
| from datetime import timedelta |
| from functools import partial |
| from itertools import chain |
|
|
| import torch |
| |
|
|
| from torch.distributed.fsdp import ( |
| FullyShardedDataParallel, |
| MixedPrecision, |
| BackwardPrefetch, |
| ShardingStrategy, |
| ) |
| from accelerate import Accelerator |
| from accelerate.utils import (DummyOptim, InitProcessGroupKwargs) |
| from accelerate.logging import get_logger |
|
|
|
|
| from datasets import load_dataset |
| from lion_pytorch import Lion |
| from torch.nn import LayerNorm |
|
|
|
|
| from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
| CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper) |
| from torch.distributed.fsdp.wrap import ( |
| transformer_auto_wrap_policy |
| ) |
|
|
|
|
| from torch.optim import AdamW |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import (AutoTokenizer, default_data_collator, |
| get_cosine_schedule_with_warmup, |
| get_linear_schedule_with_warmup, set_seed) |
|
|
|
|
| from Andromeda.utils.stable_adamw import StableAdamWUnfused |
| from Andromeda.core.transformer import Transformer, AndromedaEmbedding |
| |
| from Andromeda.model import AndromedaEmbedding |
| from Andromeda.configs import Andromeda1Billion |
|
|
| |
| import torch.distributed as dist |
|
|
|
|
| from accelerate.state import AcceleratorState |
|
|
| |
|
|
|
|
# Module-level logger obtained through Accelerate's get_logger, emitting at INFO level.
logger = get_logger(name=__name__, log_level="INFO")
|
|
class CFG:
    """Static hyper-parameters and feature switches for the Andromeda training run.

    Values are read directly as class attributes (e.g. ``CFG.BATCH_SIZE``);
    the class is never instantiated.
    """

    # Per-device micro-batch size (DataLoader batch size and DeepSpeed micro batch).
    BATCH_SIZE: int = 1
    # Number of micro-batches accumulated per optimizer step.
    GRADIENT_ACCUMULATE_EVERY: int = 1
    SEED: int = 42
    LEARNING_RATE: float = 1e-4
    WEIGHT_DECAY: float = 0.1
    # Sequence length; also the block size used when chunking raw tokenized text.
    SEQ_LEN: int = 8192
    # Worker processes used by ``datasets.map`` during preprocessing.
    NUM_CPU: int = multiprocessing.cpu_count()
    # NOTE(review): not referenced elsewhere in this file.
    USE_DEEPSPEED: bool = True
    USE_FSDP: bool = True
    # Load the pre-tokenized Hub dataset instead of tokenizing openwebtext on the fly.
    USE_PRETOKENIZED: bool = True
    USE_ACTIVATION_CHECKPOINTING: bool = True
    # Path of a ``step_<N>`` checkpoint directory to resume from; "" disables resuming.
    RESUME_FROM_CHECKPOINT: str = ""
    # Save the full training state every N optimizer steps.
    CHECKPOINTING_STEPS: int = 1000
    OUTPUT_DIR: str = 'checkpoints/'
    # NOTE(review): not referenced elsewhere in this file.
    ENTITY_NAME: str = "Andromeda"
    # Emit a log line every N micro-batches (0 disables logging).
    LOGGING_STEPS: int = 100
|
|
|
|
| |
|
|
|
|
def print_num_params(model, accelerator: Accelerator):
    """Print, through ``accelerator.print``, how many trainable parameters *model* has.

    Only parameters with ``requires_grad=True`` are counted.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    total = sum(param.numel() for param in trainable)
    accelerator.print(f"Number of parameters in model: {total}")
|
|
|
|
| |
|
|
|
|
def activation_checkpointing(
    model: torch.nn.Module,
    offload_to_cpu: bool = False,
    accelerator: Accelerator = None,
):
    """
    Wrap every ``Transformer`` submodule of *model* with non-reentrant activation checkpointing.

    The model is modified in place; nothing is returned.

    Args:
        model (Module): Model whose ``Transformer`` submodules get checkpoint wrappers.
        offload_to_cpu (bool, optional): Forwarded to ``checkpoint_wrapper``. Defaults to False.
        accelerator (Accelerator, optional): Used only to print a status line. Defaults to None.
    """
    if accelerator is not None:
        accelerator.print("Using activation checkpointing")

    wrapper_factory = partial(
        checkpoint_wrapper,
        offload_to_cpu=offload_to_cpu,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper_factory,
        check_fn=lambda submodule: isinstance(submodule, Transformer),
    )
|
|
|
|
| |
|
|
|
|
def fsdp(
    model: torch.nn.Module,
    auto_wrap: bool = False,
    mp: str = "fp32",
    shard_strat: str = "NO_SHARD",
):
    """
    Wrap *model* in ``FullyShardedDataParallel`` with the requested precision and sharding.

    Args:
        model (torch.nn.Module): The model to wrap.
        auto_wrap (bool, optional): If True, use ``transformer_auto_wrap_policy`` so every
            ``Transformer`` submodule becomes its own FSDP unit. Default is False.
        mp (str, optional): Mixed-precision mode used for params, gradient reduction and
            buffers: 'bf16', 'fp16' or 'fp32'. Default is 'fp32'.
        shard_strat (str, optional): 'SHARD_GRAD' (``ShardingStrategy.SHARD_GRAD_OP``),
            'FULL_SHARD' or 'NO_SHARD'. Default is 'NO_SHARD'.

    Raises:
        ValueError: If *mp* is not 'bf16', 'fp16' or 'fp32'.
        ValueError: If *shard_strat* is not 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD'.

    Returns:
        torch.nn.Module: The input model wrapped with FSDP.
    """
    if auto_wrap:
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={Transformer},
        )
    else:
        auto_wrap_policy = None

    precision_dtypes = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }
    if mp not in precision_dtypes:
        raise ValueError(
            "Invalid mp (mixed precision). Expected 'bf16', 'fp16' or 'fp32', got: {}".format(mp)
        )
    dtype = precision_dtypes[mp]
    mp_fsdp = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

    sharding_strategies = {
        "SHARD_GRAD": ShardingStrategy.SHARD_GRAD_OP,
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
        "NO_SHARD": ShardingStrategy.NO_SHARD,
    }
    if shard_strat not in sharding_strategies:
        raise ValueError(
            "Invalid shard_strat. Expected 'SHARD_GRAD', 'FULL_SHARD' or 'NO_SHARD', got: {}".format(
                shard_strat
            )
        )

    return FullyShardedDataParallel(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_fsdp,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=sharding_strategies[shard_strat],
        forward_prefetch=True,
        # Keep the original parameters visible through named_parameters(); the
        # per-parameter optimizer groups built in this script depend on that.
        use_orig_params=True,
    )
|
|
|
|
| |
|
|
|
|
def get_lr_scheduler_with_warmup(
    optimizer: torch.optim.Optimizer,
    scheduler_type: str,
    num_warmup_steps: int,
    max_train_steps: int,
    grad_accumulate_every: int = 1,
    accelerator: Accelerator = None,
):
    """
    Build a Hugging Face linear or cosine learning-rate schedule with linear warmup.

    Both the warmup length and the total length are multiplied by
    ``grad_accumulate_every`` before being handed to the scheduler factory.

    Args:
        optimizer (Optimizer): Optimizer whose learning rate is scheduled.
        scheduler_type (str): "linear" or "cosine".
        num_warmup_steps (int): Warmup length before scaling.
        max_train_steps (int): Total schedule length before scaling.
        grad_accumulate_every (int, optional): Scaling factor for both lengths. Defaults to 1.
        accelerator (Accelerator, optional): Used only to print a status line. Defaults to None.

    Returns:
        The scheduler built by ``get_linear_schedule_with_warmup`` or
        ``get_cosine_schedule_with_warmup``.

    Raises:
        ValueError: If scheduler_type is neither "linear" nor "cosine".
    """
    if accelerator is not None:
        accelerator.print(f"Using {scheduler_type} lr scheduler")

    if scheduler_type not in ("linear", "cosine"):
        raise ValueError(
            "Invalid scheduler_type. Expected 'linear' or 'cosine', got: {}".format(
                scheduler_type
            )
        )

    # NOTE(review): presumably scaled because lr_scheduler.step() is called once per
    # micro-batch in the training loop — confirm against the Accelerate scheduler wrapper.
    warmup_total = num_warmup_steps * grad_accumulate_every
    schedule_total = max_train_steps * grad_accumulate_every

    if scheduler_type == "linear":
        build_schedule = get_linear_schedule_with_warmup
    else:
        build_schedule = get_cosine_schedule_with_warmup

    return build_schedule(
        optimizer=optimizer,
        num_warmup_steps=warmup_total,
        num_training_steps=schedule_total,
    )
|
|
|
|
| |
|
|
|
|
def decoupled_optimizer(
    model: torch.nn.Module,
    learning_rate: float,
    weight_decay: float,
    beta_1: float,
    beta_2: float,
    optimizer_type: str,
    use_fsdp: bool = True,
    accelerator: "Accelerator | None" = None,
):
    """
    Build an optimizer with a weight-decay parameter group and a no-decay parameter group.

    Only the ``weight`` of ``torch.nn.Linear`` modules is decayed, except the output
    projection ``to_logits.weight``. Every other trainable parameter goes in the no-decay
    group, including biases, LayerNorm and Embedding weights, and the excluded output
    projection. Each trainable parameter is therefore optimized exactly once; previously
    some parameters could silently end up in no group at all.

    Args:
        model (Module): The model whose parameters are optimized.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): Weight decay applied to the decay group (the other group uses 0.0).
        beta_1 (float): The exponential decay rate for the 1st moment estimates.
        beta_2 (float): The exponential decay rate for the 2nd moment estimates.
        optimizer_type (str): 'lion', 'adamw', 'deepspeed' (Accelerate ``DummyOptim``)
            or 'stable_adamw'.
        use_fsdp (bool, optional): If True, expect the ``_fsdp_wrapped_module.`` name prefix
            when recognising the output projection. Defaults to True.
        accelerator (Accelerator, optional): Used only to print a status line. Defaults to None.

    Returns:
        Optimizer: The initialized optimizer. Its ``param_groups`` are ``[decay, no_decay]``.

    Raises:
        ValueError: If the optimizer type is not 'lion', 'adamw', 'deepspeed' or 'stable_adamw'.
    """
    if accelerator is not None:
        accelerator.print(f"Using {optimizer_type} optimizer")

    # The output projection is kept out of weight decay (it may be tied to the token embedding).
    exclude_param = (
        "_fsdp_wrapped_module.to_logits.weight" if use_fsdp else "to_logits.weight"
    )

    decay_param = []
    no_decay_param = []
    seen = set()
    for module_name, module in model.named_modules():
        for local_name, param in module.named_parameters(recurse=False):
            # A tied/shared parameter is reachable from several modules; torch optimizers
            # reject a parameter that appears in two groups, so assign each one only once.
            if not param.requires_grad or id(param) in seen:
                continue
            seen.add(id(param))
            full_name = f"{module_name}.{local_name}" if module_name else local_name
            if (
                isinstance(module, torch.nn.Linear)
                and local_name == "weight"
                and full_name != exclude_param
            ):
                decay_param.append(param)
            else:
                no_decay_param.append(param)

    grouped_params = [
        {"params": decay_param, "weight_decay": weight_decay},
        {"params": no_decay_param, "weight_decay": 0.0},
    ]

    betas = (beta_1, beta_2)
    if optimizer_type == "lion":
        optimizer = Lion(grouped_params, lr=learning_rate, betas=betas)
    elif optimizer_type == "adamw":
        optimizer = AdamW(grouped_params, lr=learning_rate, betas=betas)
    elif optimizer_type == "deepspeed":
        optimizer = DummyOptim(grouped_params, lr=learning_rate, betas=betas)
    elif optimizer_type == "stable_adamw":
        optimizer = StableAdamWUnfused(grouped_params, lr=learning_rate, betas=betas)
    else:
        raise ValueError(
            "Invalid optimizer_type. Expected 'lion', 'adamw', 'deepspeed' or 'stable_adamw', got: {}".format(
                optimizer_type
            )
        )

    return optimizer
|
|
|
|
| |
|
|
|
|
def build_dataloaders():
    """
    Tokenize openwebtext and pack it into fixed-length training blocks.

    Steps:
        1. Load the "EleutherAI/gpt-neox-20b" tokenizer.
        2. Load the "openwebtext" train split.
        3. Tokenize every text with the tokenizer's EOS token appended.
        4. Concatenate each batch of tokenized examples and cut it into ``CFG.SEQ_LEN`` blocks.

    Returns:
        Dataset: The packed dataset ready for training.
    """
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    dataset = load_dataset("openwebtext", split="train")

    tokenized_dataset = dataset.map(
        lambda example: tokenizer([t + tokenizer.eos_token for t in example["text"]]),
        batched=True,
        num_proc=CFG.NUM_CPU,
        remove_columns=["text"],
    )

    block_size = CFG.SEQ_LEN

    def group_texts(examples):
        # Flatten each column's list-of-lists into one long list.
        joined = {key: list(chain.from_iterable(examples[key])) for key in examples.keys()}
        length = len(joined[list(joined)[0]])
        # Drop the trailing remainder, but only when at least one full block exists;
        # a batch shorter than block_size is kept as a single short block.
        if length >= block_size:
            length -= length % block_size
        offsets = range(0, length, block_size)
        return {
            key: [tokens[start : start + block_size] for start in offsets]
            for key, tokens in joined.items()
        }

    return tokenized_dataset.map(
        group_texts, batched=True, num_proc=CFG.NUM_CPU,
    )
|
|
| |
def build_pre_tokenized(
    dataset_name: str = "conceptofmind/c4_0-to-20_neox_with_eos_8k",
    split: str = "train[:10]",
):
    """
    Load an already-tokenized dataset from the Hugging Face Hub.

    Args:
        dataset_name (str, optional): Hub dataset to load. Defaults to the C4 GPT-NeoX
            pre-tokenized set used by this script.
        split (str, optional): Split expression passed to ``load_dataset``.
            NOTE(review): the default ``"train[:10]"`` loads only the first 10 rows, which
            looks like a debugging leftover; pass ``split="train"`` for a real run.

    Returns:
        Dataset: The loaded dataset.
    """
    return load_dataset(dataset_name, split=split)
|
|
|
|
|
|
def _resume_step_from_checkpoint(checkpoint_path):
    """Return the optimizer-step count N parsed from a ``.../step_<N>`` checkpoint path."""
    # normpath drops a trailing separator so basename() does not return an empty string.
    name = os.path.basename(os.path.normpath(checkpoint_path))
    return int(os.path.splitext(name)[0].replace("step_", ""))


def _save_final_model(accelerator, model):
    """Write the full model state dict to ``<CFG.OUTPUT_DIR>/final/final_model.pt``."""
    accelerator.wait_for_everyone()
    # Every rank calls get_state_dict: gathering FSDP / DeepSpeed shards may be a collective,
    # so it must not run inside a main-process-only section.
    state_dict = accelerator.get_state_dict(model)
    final_dir = os.path.join(CFG.OUTPUT_DIR, "final")
    if accelerator.is_local_main_process:
        # torch.save does not create missing parent directories.
        os.makedirs(final_dir, exist_ok=True)
    accelerator.save(state_dict, os.path.join(final_dir, "final_model.pt"))


def Train():
    """
    Run the full Andromeda training loop described by ``CFG``.

    This function:
        - builds the Accelerator and trackers;
        - builds the model (optionally FSDP-wrapped and activation-checkpointed);
        - builds the dataset, optimizer and cosine LR schedule;
        - optionally resumes from ``CFG.RESUME_FROM_CHECKPOINT``;
        - trains, saving full training state every ``CFG.CHECKPOINTING_STEPS`` optimizer steps;
        - finally writes the model weights to ``CFG.OUTPUT_DIR/final/final_model.pt``.
    """
    # Process-group timeout of 1,000,000 s (~11.6 days). NOTE: this only takes effect if
    # Accelerate creates the process group itself.
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))

    accelerator = Accelerator(
        gradient_accumulation_steps=CFG.GRADIENT_ACCUMULATE_EVERY,
        mixed_precision="fp16",
        log_with="wandb",
        kwargs_handlers=[timeout],
    )

    state = AcceleratorState()
    # deepspeed_plugin is None unless the run was launched with a DeepSpeed config.
    if state.deepspeed_plugin is not None:
        state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = CFG.BATCH_SIZE

    accelerator.init_trackers(
        project_name="Andromeda",
        config={
            "batch_size": CFG.BATCH_SIZE,
            "gradient_accumulate_every": CFG.GRADIENT_ACCUMULATE_EVERY,
            "learning_rate": CFG.LEARNING_RATE,
            "seq_len": CFG.SEQ_LEN,
        },
    )

    accelerator.print(f"Total GPUS: {accelerator.num_processes}")

    set_seed(CFG.SEED)

    model = Andromeda1Billion()
    print_num_params(model, accelerator)

    if CFG.USE_FSDP:
        model = fsdp(model, mp="fp16", shard_strat="SHARD_GRAD")

    if CFG.USE_ACTIVATION_CHECKPOINTING:
        # Pass by keyword: positionally the accelerator would land in offload_to_cpu.
        activation_checkpointing(model, accelerator=accelerator)

    model = accelerator.prepare(model)

    if CFG.USE_PRETOKENIZED:
        train_dataset = build_pre_tokenized()
    else:
        train_dataset = build_dataloaders()

    train_loader = DataLoader(
        train_dataset, batch_size=CFG.BATCH_SIZE, collate_fn=default_data_collator,
    )

    optim = decoupled_optimizer(
        model=model,
        learning_rate=CFG.LEARNING_RATE,
        weight_decay=CFG.WEIGHT_DECAY,
        beta_1=0.90,
        beta_2=0.95,
        optimizer_type='lion',
        # Parameter-name prefixes depend on whether the model was actually FSDP-wrapped.
        use_fsdp=CFG.USE_FSDP,
        accelerator=accelerator,
    )

    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps: {max_train_steps}")

    num_warmup_steps = int(max_train_steps * 0.01)
    accelerator.print(f"Num warmup steps: {num_warmup_steps}")

    lr_scheduler = get_lr_scheduler_with_warmup(
        optimizer=optim,
        scheduler_type="cosine",
        num_warmup_steps=num_warmup_steps,
        max_train_steps=max_train_steps,
        grad_accumulate_every=CFG.GRADIENT_ACCUMULATE_EVERY,
        accelerator=accelerator,
    )

    optim, train_loader, lr_scheduler = accelerator.prepare(
        optim, train_loader, lr_scheduler
    )

    accelerator.register_for_checkpointing(lr_scheduler)

    # prepare() shards the loader across processes, so its length changes.
    max_train_steps = math.ceil(len(train_loader) / CFG.GRADIENT_ACCUMULATE_EVERY)
    accelerator.print(f"Max train steps recalculated: {max_train_steps}")

    total_batch_size = (
        CFG.BATCH_SIZE * accelerator.num_processes * CFG.GRADIENT_ACCUMULATE_EVERY
    )
    accelerator.print(f"Total batch size: {total_batch_size}")

    progress_bar = tqdm(
        range(max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0  # optimizer steps (not micro-batches)

    resume_step = None
    if CFG.RESUME_FROM_CHECKPOINT:
        accelerator.print(f"Resuming from checkpoint {CFG.RESUME_FROM_CHECKPOINT}")
        accelerator.load_state(CFG.RESUME_FROM_CHECKPOINT)
        # Checkpoints are named step_<completed_steps>, i.e. counted in optimizer steps.
        resume_step = _resume_step_from_checkpoint(CFG.RESUME_FROM_CHECKPOINT)

    if resume_step is not None:
        # The loader yields micro-batches: skip GRADIENT_ACCUMULATE_EVERY of them per
        # completed optimizer step, but advance the step counters by optimizer steps.
        train_loader = accelerator.skip_first_batches(
            train_loader, resume_step * CFG.GRADIENT_ACCUMULATE_EVERY
        )
        completed_steps += resume_step
        progress_bar.update(resume_step)

    model.train()
    for step, batch in enumerate(train_loader):
        with accelerator.accumulate(model):
            inputs = batch["input_ids"].to(accelerator.device)
            loss = model(inputs, return_loss=True)
            accelerator.backward(loss)

            accelerator.log({"loss": loss.item()}, step=step)

            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)

            optim.step()
            lr_scheduler.step()
            optim.zero_grad()

        # sync_gradients is True only on micro-batches that complete an optimizer step.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            # Checked only on real optimizer steps so that, with gradient accumulation,
            # the same step is not saved repeatedly (or "step_0" saved before any update).
            if (
                isinstance(CFG.CHECKPOINTING_STEPS, int)
                and completed_steps % CFG.CHECKPOINTING_STEPS == 0
            ):
                output_dir = f"step_{completed_steps}"
                if CFG.OUTPUT_DIR is not None:
                    output_dir = os.path.join(CFG.OUTPUT_DIR, output_dir)
                accelerator.save_state(output_dir)

        if completed_steps >= max_train_steps:
            break

        if CFG.LOGGING_STEPS > 0 and step % CFG.LOGGING_STEPS == 0:
            logger.info(
                f"Step: {completed_steps}/{max_train_steps}, Loss: {loss.item():.5f}"
            )

    accelerator.end_training()

    if CFG.OUTPUT_DIR is not None:
        _save_final_model(accelerator, model)
|
|
|
|
def main():
    """
    Script entry point: validate the distributed environment, create the NCCL process group, then train.

    Raises:
        KeyError: If MASTER_ADDR, MASTER_PORT, RANK or WORLD_SIZE is missing from the
            environment (normally set by a launcher such as ``torchrun``).
    """
    required = ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE")
    missing = [name for name in required if name not in os.environ]
    if missing:
        raise KeyError(
            f"Missing required environment variable(s): {', '.join(missing)}"
        )

    if not dist.is_initialized():
        # Use the same very long timeout that Train() requests via InitProcessGroupKwargs:
        # once the group exists, Accelerate reuses it and never applies that kwarg.
        dist.init_process_group(backend='nccl', timeout=timedelta(seconds=1_000_000))

    Train()


if __name__ == '__main__':
    main()