# Hosting-page residue, kept as comments so the file stays valid Python:
# uploaded by OneMore1 - "Sync from GitHub FlexiBrain main" - commit 6a51385 (verified) - 7.47 kB
from __future__ import annotations
import argparse
import json
from dataclasses import asdict
import torch
from flexibrain.config import load_config
from flexibrain.engine import DownstreamTrainer, Pretrainer
from flexibrain.models import build_downstream_model, build_pretrain_model
def _add_common(parser):
parser.add_argument("--config", default=None)
parser.add_argument("--train-list", default=None)
parser.add_argument("--val-list", default=None)
parser.add_argument("--test-list", default=None)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--num-workers", type=int, default=None)
parser.add_argument("--t-prime", type=int, default=None)
parser.add_argument("--tau-seconds", type=float, default=None)
parser.add_argument("--default-tr", type=float, default=None, help="Fallback TR in seconds when a NIfTI header has no valid TR.")
parser.add_argument("--epochs", type=int, default=None)
parser.add_argument("--lr", type=float, default=None)
parser.add_argument("--weight-decay", type=float, default=None)
parser.add_argument("--warmup-epochs", type=int, default=None)
parser.add_argument("--grad-accumulation-steps", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--log-interval", type=int, default=None)
parser.add_argument("--checkpoint-dir", default=None)
parser.add_argument("--log-dir", default=None)
parser.add_argument("--local-rank", type=int, default=None)
parser.add_argument("--world-size", type=int, default=None)
parser.add_argument("--use-amp", action="store_true", default=None)
parser.add_argument("--no-use-amp", dest="use_amp", action="store_false")
parser.add_argument("--dry-run", action="store_true")
parser.set_defaults(use_amp=None)
def _add_model(parser):
parser.add_argument("--embed-dim", type=int, default=None)
parser.add_argument("--depth", type=int, default=None)
parser.add_argument("--predictor-depth", type=int, default=None)
parser.add_argument("--drop-path-rate", type=float, default=None)
parser.add_argument("--bimamba-type", default=None)
parser.add_argument("--if-bimamba", action="store_true", default=None)
parser.add_argument("--if-devide-out", action="store_true", default=None)
parser.add_argument("--no-if-devide-out", dest="if_devide_out", action="store_false")
parser.add_argument("--mixer-type", default=None)
parser.add_argument("--momentum", type=float, default=None)
parser.add_argument("--final-momentum", type=float, default=None)
def apply_common(cfg, args):
    """Copy the shared command-line overrides from ``args`` onto ``cfg``.

    Only values that are not ``None`` are written; every other attribute of
    ``cfg`` is left untouched.  Overrides are routed to ``cfg.data``,
    ``cfg.training`` or ``cfg.logging`` depending on the option.
    """
    # Options stored under the same attribute name on ``args`` and on the
    # config section.  An attribute missing from ``args`` counts as None.
    same_name = (
        ("data", ("train_list", "val_list", "test_list", "batch_size", "num_workers")),
        (
            "training",
            (
                "epochs",
                "lr",
                "weight_decay",
                "warmup_epochs",
                "grad_accumulation_steps",
                "seed",
                "local_rank",
                "world_size",
            ),
        ),
    )
    for section, names in same_name:
        for name in names:
            override = getattr(args, name, None)
            if override is not None:
                setattr(getattr(cfg, section), name, override)

    # (attribute on args, config section, attribute on that section).
    # These are read strictly: a missing attribute on ``args`` raises.
    routed = (
        ("t_prime", "data", "T_prime"),
        ("tau_seconds", "data", "tau_seconds"),
        ("default_tr", "data", "default_tr"),
        ("use_amp", "training", "use_amp"),
        ("log_interval", "logging", "log_interval"),
        ("checkpoint_dir", "logging", "checkpoint_dir"),
        ("log_dir", "logging", "log_dir"),
    )
    for source, section, dest in routed:
        override = getattr(args, source)
        if override is not None:
            setattr(getattr(cfg, section), dest, override)
def apply_model(cfg, args):
    """Overlay the model options present on ``args`` onto ``cfg.model``.

    An option is copied only when its value on ``args`` is not ``None``; an
    attribute missing from ``args`` is treated the same as ``None``.
    """
    model_options = (
        "embed_dim",
        "depth",
        "predictor_depth",
        "drop_path_rate",
        "bimamba_type",
        "if_bimamba",
        "if_devide_out",
        "mixer_type",
        "momentum",
        "final_momentum",
    )
    for name in model_options:
        override = getattr(args, name, None)
        if override is None:
            continue
        setattr(cfg.model, name, override)
def parse_args(argv=None) -> argparse.Namespace:
    """Build the ``flexibrain`` command-line parser and parse ``argv``.

    The parser has two required sub-commands, ``pretrain`` and
    ``downstream``; the chosen one is stored in ``command``.  Both receive
    the shared options from ``_add_common`` and ``_add_model``, plus their
    own extras registered below.

    Args:
        argv: Argument strings to parse.  ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        The parsed namespace.  Value options are ``None`` when not given.
        ``--from-scratch``, ``--ignore-checkpoint-config`` and
        ``--freeze-backbone`` are plain switches and are ``False`` when not
        given.
    """
    parser = argparse.ArgumentParser(prog="flexibrain")
    sub = parser.add_subparsers(dest="command", required=True)

    pretrain = sub.add_parser("pretrain")
    _add_common(pretrain)
    _add_model(pretrain)
    pretrain.add_argument("--mask-ratio", type=float, default=None)
    pretrain.add_argument("--grad-clip", type=float, default=None)

    downstream = sub.add_parser("downstream")
    _add_common(downstream)
    _add_model(downstream)
    downstream.add_argument("--csv", default=None)
    downstream.add_argument("--id-column", default=None)
    downstream.add_argument("--label-column", default=None)
    downstream.add_argument("--label-mode", default=None)
    downstream.add_argument("--path-id-mode", default=None)
    downstream.add_argument("--pretrain-checkpoint", default=None)
    downstream.add_argument("--from-scratch", action="store_true")
    downstream.add_argument("--ignore-checkpoint-config", action="store_true")
    downstream.add_argument("--num-classes", type=int, default=None)
    downstream.add_argument("--head-type", choices=["transformer", "avgpool"], default=None)
    downstream.add_argument("--freeze-backbone", action="store_true")
    downstream.add_argument("--lr-backbone", type=float, default=None)
    downstream.add_argument("--lr-head", type=float, default=None)

    return parser.parse_args(argv)
def _print_dry_run(cfg, model):
    """Print the resolved config and the model's parameter count as JSON."""
    report = {
        "config": asdict(cfg),
        "parameters": sum(p.numel() for p in model.parameters()),
    }
    print(json.dumps(report, indent=2))


def main():
    """Run the ``flexibrain`` command-line interface.

    Loads the config named by ``--config``, overlays the command-line
    overrides, then dispatches on the sub-command: ``pretrain`` runs
    ``Pretrainer(cfg).fit()`` and ``downstream`` runs
    ``DownstreamTrainer(cfg).fit()``.  With ``--dry-run`` the model is built
    on the CPU device and a JSON report (config plus parameter count) is
    printed instead of training.
    """
    args = parse_args()
    cfg = load_config(args.config)
    apply_common(cfg, args)
    apply_model(cfg, args)

    if args.command == "pretrain":
        if args.mask_ratio is not None:
            cfg.training.mask_ratio = args.mask_ratio
        if args.grad_clip is not None:
            cfg.training.grad_clip = args.grad_clip
        if args.dry_run:
            model = build_pretrain_model(cfg.model, torch.device("cpu"))
            _print_dry_run(cfg, model)
            return
        Pretrainer(cfg).fit()
    elif args.command == "downstream":
        for key in ["csv", "id_column", "label_column", "label_mode", "path_id_mode"]:
            value = getattr(args, key, None)
            if value is not None:
                setattr(cfg.data, key, value)
        for key in ["num_classes", "head_type"]:
            value = getattr(args, key, None)
            if value is not None:
                setattr(cfg.model, key, value)
        # --freeze-backbone is a store_true switch, so it parses as False
        # (never None) when absent.  Only override the config when the flag
        # was actually passed; otherwise a config that enables freezing
        # would be silently reset to False on every run.
        if getattr(args, "freeze_backbone", None):
            cfg.model.freeze_backbone = True
        if args.lr_backbone is not None:
            cfg.training.lr_backbone = args.lr_backbone
        if args.lr_head is not None:
            cfg.training.lr_head = args.lr_head
        if args.pretrain_checkpoint is not None:
            cfg.pretrain_checkpoint = args.pretrain_checkpoint
        if args.from_scratch:
            cfg.from_scratch = True
        if args.ignore_checkpoint_config:
            cfg.use_checkpoint_config = False
        if args.dry_run:
            model = build_downstream_model(
                cfg.model,
                torch.device("cpu"),
                checkpoint_path=cfg.pretrain_checkpoint,
                from_scratch=cfg.from_scratch,
                use_checkpoint_config=cfg.use_checkpoint_config,
            )
            _print_dry_run(cfg, model)
            return
        DownstreamTrainer(cfg).fit()
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()