# MeMDLM training entry-point script.
import os
import gc
import sys
import torch
import wandb
import torch.nn as nn
import lightning.pytorch as pl
from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from src.lm.memdlm.diffusion_module import MembraneDiffusion
from src.lm.memdlm.dataloader import MembraneDataModule, get_datasets
from src.utils.model_utils import apply_rdm_freezing
# Authenticate with Weights & Biases.
# SECURITY: an API key was previously hard-coded here and committed to the
# repository — treat that key as leaked and rotate it. Read the key from the
# environment instead (export WANDB_API_KEY=...); with key=None, wandb.login()
# falls back to credentials stored by a prior `wandb login`.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
# Load the experiment YAML config. The original absolute cluster path is kept
# as the default for backward compatibility, but can be overridden via the
# MEMDLM_CONFIG environment variable so the script also runs off-cluster.
config = OmegaConf.load(
    os.environ.get(
        "MEMDLM_CONFIG",
        "/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml",
    )
)
# Build the train/val/test datasets and wrap them in the Lightning DataModule.
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
)
# Initialize WandB for logging
# NOTE(review): wandb.init() here *and* WandbLogger(**config.wandb) below may
# each start a run — WandbLogger normally calls wandb.init itself; confirm the
# intended run layout (one run vs. two).
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)
# PL checkpoints
# Log the learning rate on every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval="step")
# Keep only the single best checkpoint by validation loss, evaluated every
# `save_every_n_steps` training steps; written to the configured save_dir.
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    mode="min",
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
    every_n_train_steps=config.checkpointing.save_every_n_steps
)
# PL trainer
# Training length is governed solely by max_steps (max_epochs=None below).
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    max_epochs=None, # Ensure training is based on num steps
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    # All configured devices for training; a single device ([0]) otherwise.
    devices=config.training.devices if config.training.mode=='train' else [0],
    # find_unused_parameters=True — presumably needed because layer freezing
    # (apply_rdm_freezing, below) leaves parameters without gradients under
    # DDP; TODO confirm, as it adds per-step overhead.
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps
)
# Folder to save checkpoints — exist_ok=True replaces the previous
# try/except FileExistsError dance with the idiomatic one-liner.
ckpt_path = config.checkpointing.save_dir
os.makedirs(ckpt_path, exist_ok=True)

# PL Model for training
diffusion = MembraneDiffusion(config)
diffusion.validate_config()

# Start/resume training or evaluate the model
model_type = "evoflow"
if config.training.mode == "train":
    # Freeze layers per config before fine-tuning.
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    trainer.fit(diffusion, datamodule=data_module)
elif config.training.mode == "test":
    # Load the best weights, then evaluate on the test split.
    state_dict = diffusion.get_state_dict(config.checkpointing.best_ckpt_path)
    diffusion.load_state_dict(state_dict)
    trainer.test(diffusion, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)
elif config.training.mode == "resume_from_checkpoint":
    state_dict = diffusion.get_state_dict(config.training.resume_ckpt_path)
    diffusion.load_state_dict(state_dict)
    apply_rdm_freezing(diffusion.model, config.training.n_layers, model_type)
    # BUG FIX: this previously passed `ckpt_path` (the checkpoint *directory*,
    # save_dir) to Trainer.fit, which expects a checkpoint *file*. Resume from
    # the configured checkpoint so optimizer/scheduler/step state are restored.
    trainer.fit(diffusion, datamodule=data_module, ckpt_path=config.training.resume_ckpt_path)

wandb.finish()