from __future__ import annotations import argparse import json from dataclasses import asdict import torch from flexibrain.config import load_config from flexibrain.engine import DownstreamTrainer, Pretrainer from flexibrain.models import build_downstream_model, build_pretrain_model def _add_common(parser): parser.add_argument("--config", default=None) parser.add_argument("--train-list", default=None) parser.add_argument("--val-list", default=None) parser.add_argument("--test-list", default=None) parser.add_argument("--batch-size", type=int, default=None) parser.add_argument("--num-workers", type=int, default=None) parser.add_argument("--t-prime", type=int, default=None) parser.add_argument("--tau-seconds", type=float, default=None) parser.add_argument("--default-tr", type=float, default=None, help="Fallback TR in seconds when a NIfTI header has no valid TR.") parser.add_argument("--epochs", type=int, default=None) parser.add_argument("--lr", type=float, default=None) parser.add_argument("--weight-decay", type=float, default=None) parser.add_argument("--warmup-epochs", type=int, default=None) parser.add_argument("--grad-accumulation-steps", type=int, default=None) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--log-interval", type=int, default=None) parser.add_argument("--checkpoint-dir", default=None) parser.add_argument("--log-dir", default=None) parser.add_argument("--local-rank", type=int, default=None) parser.add_argument("--world-size", type=int, default=None) parser.add_argument("--use-amp", action="store_true", default=None) parser.add_argument("--no-use-amp", dest="use_amp", action="store_false") parser.add_argument("--dry-run", action="store_true") parser.set_defaults(use_amp=None) def _add_model(parser): parser.add_argument("--embed-dim", type=int, default=None) parser.add_argument("--depth", type=int, default=None) parser.add_argument("--predictor-depth", type=int, default=None) parser.add_argument("--drop-path-rate", type=float, default=None) parser.add_argument("--bimamba-type", default=None) parser.add_argument("--if-bimamba", action="store_true", default=None) parser.add_argument("--if-devide-out", action="store_true", default=None) parser.add_argument("--no-if-devide-out", dest="if_devide_out", action="store_false") parser.add_argument("--mixer-type", default=None) parser.add_argument("--momentum", type=float, default=None) parser.add_argument("--final-momentum", type=float, default=None) def apply_common(cfg, args): for key in ["train_list", "val_list", "test_list", "batch_size", "num_workers", "epochs", "lr", "weight_decay", "warmup_epochs", "grad_accumulation_steps", "seed", "local_rank", "world_size"]: value = getattr(args, key, None) if value is None: continue target = cfg.data if key in {"train_list", "val_list", "test_list", "batch_size", "num_workers"} else cfg.training setattr(target, key, value) if args.t_prime is not None: cfg.data.T_prime = args.t_prime if args.tau_seconds is not None: cfg.data.tau_seconds = args.tau_seconds if args.default_tr is not None: cfg.data.default_tr = args.default_tr if args.use_amp is not None: cfg.training.use_amp = args.use_amp if args.log_interval is not None: cfg.logging.log_interval = args.log_interval if args.checkpoint_dir is not None: cfg.logging.checkpoint_dir = args.checkpoint_dir if args.log_dir is not None: cfg.logging.log_dir = args.log_dir def apply_model(cfg, args): for key in ["embed_dim", "depth", "predictor_depth", "drop_path_rate", "bimamba_type", "if_bimamba", "if_devide_out", "mixer_type", "momentum", "final_momentum"]: value = getattr(args, key, None) if value is not None: setattr(cfg.model, key, value) def parse_args(): parser = argparse.ArgumentParser(prog="flexibrain") sub = parser.add_subparsers(dest="command", required=True) pretrain = sub.add_parser("pretrain") _add_common(pretrain) _add_model(pretrain) pretrain.add_argument("--mask-ratio", type=float, default=None) pretrain.add_argument("--grad-clip", type=float, default=None) downstream = sub.add_parser("downstream") _add_common(downstream) _add_model(downstream) downstream.add_argument("--csv", default=None) downstream.add_argument("--id-column", default=None) downstream.add_argument("--label-column", default=None) downstream.add_argument("--label-mode", default=None) downstream.add_argument("--path-id-mode", default=None) downstream.add_argument("--pretrain-checkpoint", default=None) downstream.add_argument("--from-scratch", action="store_true") downstream.add_argument("--ignore-checkpoint-config", action="store_true") downstream.add_argument("--num-classes", type=int, default=None) downstream.add_argument("--head-type", choices=["transformer", "avgpool"], default=None) downstream.add_argument("--freeze-backbone", action="store_true") downstream.add_argument("--lr-backbone", type=float, default=None) downstream.add_argument("--lr-head", type=float, default=None) return parser.parse_args() def main(): args = parse_args() cfg = load_config(args.config) apply_common(cfg, args) apply_model(cfg, args) if args.command == "pretrain": if args.mask_ratio is not None: cfg.training.mask_ratio = args.mask_ratio if args.grad_clip is not None: cfg.training.grad_clip = args.grad_clip if args.dry_run: model = build_pretrain_model(cfg.model, torch.device("cpu")) print(json.dumps({"config": asdict(cfg), "parameters": sum(p.numel() for p in model.parameters())}, indent=2)) return Pretrainer(cfg).fit() elif args.command == "downstream": for key in ["csv", "id_column", "label_column", "label_mode", "path_id_mode"]: value = getattr(args, key, None) if value is not None: setattr(cfg.data, key, value) for key in ["num_classes", "head_type", "freeze_backbone"]: value = getattr(args, key, None) if value is not None: setattr(cfg.model, key, value) if args.lr_backbone is not None: cfg.training.lr_backbone = args.lr_backbone if args.lr_head is not None: cfg.training.lr_head = args.lr_head if args.pretrain_checkpoint is not None: cfg.pretrain_checkpoint = args.pretrain_checkpoint if args.from_scratch: cfg.from_scratch = True if args.ignore_checkpoint_config: cfg.use_checkpoint_config = False if args.dry_run: model = build_downstream_model(cfg.model, torch.device("cpu"), checkpoint_path=cfg.pretrain_checkpoint, from_scratch=cfg.from_scratch, use_checkpoint_config=cfg.use_checkpoint_config) print(json.dumps({"config": asdict(cfg), "parameters": sum(p.numel() for p in model.parameters())}, indent=2)) return DownstreamTrainer(cfg).fit() if __name__ == "__main__": main()