
MS3SEG: Pre-trained Models for MS Lesion Segmentation

Paper • Dataset • Code • License: CC BY 4.0

Pre-trained deep learning models for Multiple Sclerosis (MS) lesion segmentation, trained on the MS3SEG dataset.

Note: These are representative models from Fold 4 of our 5-fold cross-validation. Complete training code and all fold results are available in our GitHub repository.


📋 Repository Contents

MS3SEG/
├── kfold_brain_segmentation_20250924_232752_unified_focal_loss/models/
│   ├── binary_abnormal_wmh/      # Binary MS lesion segmentation
│   │   ├── u-net_fold_4_best.h5
│   │   ├── unet++_fold_4_best.h5
│   │   ├── unetr_fold_4_best.h5
│   │   └── swinunetr_fold_4_best.h5
│   │
│   ├── binary_ventricles/        # Binary ventricle segmentation
│   │   ├── u-net_fold_4_best.h5
│   │   ├── unet++_fold_4_best.h5
│   │   ├── unetr_fold_4_best.h5
│   │   └── swinunetr_fold_4_best.h5
│   │
│   └── multi_class/              # 4-class tri-mask segmentation
│       ├── u-net_fold_4_best.h5
│       ├── unet++_fold_4_best.h5
│       ├── unetr_fold_4_best.h5
│       └── swinunetr_fold_4_best.h5
│
├── figures/
│   ├── training_curves/          # Loss and metrics across epochs
│   └── sample_predictions/       # Visual results from paper
│
├── config/
│   └── experiment_config.json    # Model training configuration
└── README.md                     # This file

Total Size: ~1.2 GB (12 model files)


🎯 Model Overview

Segmentation Scenarios

| Scenario | Classes | Description |
|---|---|---|
| Multi-class | 4 | Background, Ventricles, Normal WMH, Abnormal WMH (MS lesions) |
| Binary Lesion | 2 | MS lesions vs. everything else |
| Binary Ventricle | 2 | Ventricles vs. everything else |

Model Architectures

  • U-Net: Classic encoder-decoder with skip connections
  • U-Net++: Nested skip pathways for improved feature propagation
  • UNETR: Vision Transformer encoder with CNN decoder
  • Swin UNETR: Hierarchical shifted-window attention

All models were trained on 256×256 axial FLAIR images from 64 patients (Fold 4 training set).
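A quick way to check which scenario a given checkpoint implements is to inspect its tensor shapes after loading; a minimal sketch (the filename is a placeholder for any checkpoint in the tree above):

from tensorflow import keras

# Placeholder filename; substitute any checkpoint from the repository tree.
model = keras.models.load_model("u-net_fold_4_best.h5", compile=False)
print(model.input_shape)   # (None, 256, 256, 1): single-channel axial FLAIR slices
print(model.output_shape)  # (None, 256, 256, 4) for multi-class, (None, 256, 256, 1) for binary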


📊 Performance (Fold 4 Validation Results)

Multi-Class Segmentation (Dice Score)

| Model | Ventricles | Normal WMH | Abnormal WMH | Mean |
|---|---|---|---|---|
| U-Net | 0.8967 | 0.5935 | 0.6709 | 0.7204 |
| U-Net++ | 0.8904 | 0.5881 | 0.6512 | 0.7099 |
| UNETR | 0.8401 | 0.4692 | 0.6632 | 0.6575 |
| Swin UNETR | 0.8608 | 0.5203 | 0.5920 | 0.6577 |

Binary Lesion Segmentation

| Model | Dice | IoU | HD95 (mm) |
|---|---|---|---|
| U-Net | 0.7407 | 0.5882 | 32.64 |
| U-Net++ | 0.5930 | 0.4215 | 35.12 |
| UNETR | 0.6632 | 0.4963 | 40.85 |
| Swin UNETR | 0.5841 | 0.4127 | 38.19 |

Binary Ventricle Segmentation

| Model | Dice | IoU | HD95 (mm) |
|---|---|---|---|
| U-Net | 0.8967 | 0.8130 | 9.52 |
| U-Net++ | 0.8904 | 0.8026 | 10.18 |
| Swin UNETR | 0.8608 | 0.7560 | 12.73 |
| UNETR | 0.8401 | 0.7240 | 14.92 |

Results are from the validation set of Fold 4. See the paper for complete 5-fold statistics.
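For reference, Dice and IoU follow the standard overlap definitions; a minimal NumPy sketch for binary masks (HD95 additionally requires a distance-transform implementation and is omitted here):

import numpy as np

def dice_score(pred, gt, eps=1e-7):
    # Dice = 2|A ∩ B| / (|A| + |B|)
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def iou_score(pred, gt, eps=1e-7):
    # IoU = |A ∩ B| / |A ∪ B|
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return (inter + eps) / (union + eps)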


🚀 Quick Start

Installation

pip install "tensorflow>=2.10.0" huggingface_hub nibabel numpy

Load and Use Models

from tensorflow import keras
from huggingface_hub import hf_hub_download
import numpy as np

# Download model (path matches the repository tree above)
model_path = hf_hub_download(
    repo_id="Bawil/MS3SEG",
    filename="kfold_brain_segmentation_20250924_232752_unified_focal_loss/models/multi_class/u-net_fold_4_best.h5"
)

# Load model
model = keras.models.load_model(model_path, compile=False)

# Prepare your data: a preprocessed 256x256 FLAIR slice,
# shaped (batch, 256, 256, 1); see Input Requirements below
image = np.zeros((1, 256, 256, 1), dtype=np.float32)  # placeholder input
predictions = model.predict(image)

# For multi-class: get class labels
pred_classes = np.argmax(predictions, axis=-1)
# Classes: 0=background, 1=ventricles, 2=normal WMH, 3=abnormal WMH

# For binary models (single-channel output): apply a threshold
pred_binary = (predictions > 0.5).astype(np.uint8)
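To run a model over a whole preprocessed volume slice by slice and save the result, a minimal sketch with nibabel (assumes `model` from the snippet above and a volume that already meets the Input Requirements below; `flair.nii.gz` is a placeholder path):

import nibabel as nib
import numpy as np

nii = nib.load("flair.nii.gz")                       # placeholder path
vol = nii.get_fdata().astype(np.float32)             # (256, 256, n_slices), already in [0, 1]

# Stack axial slices into a batch: (n_slices, 256, 256, 1)
batch = np.transpose(vol, (2, 0, 1))[..., np.newaxis]
pred = model.predict(batch)
labels = np.argmax(pred, axis=-1).astype(np.uint8)   # multi-class; threshold instead for binary models

# Back to volume orientation and save with the original affine
seg = np.transpose(labels, (1, 2, 0))
nib.save(nib.Nifti1Image(seg, nii.affine), "segmentation.nii.gz")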

Download All Models for One Scenario

from huggingface_hub import snapshot_download

# Download an entire scenario folder (pattern matches the tree above)
snapshot_download(
    repo_id="Bawil/MS3SEG",
    allow_patterns="kfold_brain_segmentation_20250924_232752_unified_focal_loss/models/multi_class/*",
    local_dir="./ms3seg_models"
)
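Afterwards, each downloaded checkpoint can be loaded in one pass; a small sketch:

import glob
from tensorflow import keras

# Load every .h5 checkpoint found under the local download directory
models = {path: keras.models.load_model(path, compile=False)
          for path in glob.glob("./ms3seg_models/**/*.h5", recursive=True)}
print(f"Loaded {len(models)} models")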

πŸ“ Input Requirements

  • Format: NIfTI (.nii.gz) or NumPy array
  • Modality: T2-FLAIR (axial plane)
  • Dimensions: 256 Γ— 256 pixels
  • Channels: 1 (grayscale)
  • Preprocessing:
    • Co-registered to FLAIR space
    • Brain-extracted
    • Intensity normalized to [0, 1]
    • Voxel spacing: ~0.9 Γ— 0.9 Γ— 5.7 mmΒ³

See preprocessing scripts in our GitHub repository.
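As a rough illustration of the normalization and resizing steps only (a sketch; the authors' actual pipeline lives in those scripts), using TensorFlow for the resize:

import numpy as np
import tensorflow as tf

def preprocess_slice(slice_2d):
    # Min-max normalize to [0, 1], guarding against constant slices
    lo, hi = float(slice_2d.min()), float(slice_2d.max())
    norm = (slice_2d - lo) / (hi - lo) if hi > lo else np.zeros_like(slice_2d)
    # Bilinear resize to 256 x 256 (interpolation choice is an assumption)
    resized = tf.image.resize(norm[..., np.newaxis], (256, 256)).numpy()
    return resized.astype(np.float32)  # (256, 256, 1), ready to batch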


📖 Dataset Information

MS3SEG is a Multiple Sclerosis MRI dataset with unique tri-mask annotations:

  • 100 patients from an Iranian cohort (1.5T Toshiba scanner)
  • ~2000 annotated slices with expert consensus
  • 4 annotation classes: Background, Ventricles, Normal WMH, Abnormal WMH
  • Multiple sequences: T1w, T2w, T2-FLAIR (axial + sagittal)

Dataset Access: Figshare Repository (CC-BY-4.0 License)


🔧 Model Training Details

All models were trained with:

  • Loss Function: Unified Focal Loss (combining Dice and Focal components)
  • Optimizer: Adam (lr=1e-4)
  • Batch Size: 4
  • Epochs: 100 (with early stopping, patience=10)
  • Data Split: 64 train / 16 validation patients (Fold 4)
  • Framework: TensorFlow 2.10+

Complete training configuration is available in config/experiment_config.json.
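For orientation, these settings map onto Keras roughly as follows (a sketch, not the actual training script; `model`, `x_train`, `y_train`, `x_val`, `y_val`, and `unified_focal_loss` are placeholders, with the loss implementation in our GitHub repository):

from tensorflow import keras

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=unified_focal_loss,  # placeholder: see the GitHub repository
)
early_stop = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=10,
    restore_best_weights=True,  # an assumption, matching the "*_best.h5" naming
)
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=100, batch_size=4, callbacks=[early_stop])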


📚 Citation

If you use these models in your research, please cite our paper:

@article{bashiri2026ms3seg,
  title={A Multiple Sclerosis MRI Dataset with Tri-Mask Annotations for Lesion Segmentation},
  author={Bashiri Bawil, Mahdi and Shamsi, Mousa and Ghalehasadi, Aydin and Jafargholkhanloo, Ali Fahmi and Shakeri Bavil, Abolhassan},
  journal={Scientific Data},
  year={2026},
  doi={10.6084/m9.figshare.30393475},
  publisher={Nature Publishing Group}
}


⚠️ Important Notes

  1. Fold 4 Only: These models represent one fold (Fold 4) from our 5-fold cross-validation. They demonstrate representative performance but should not be considered the final "best" models across all folds.

  2. Research Use: These models are provided for research purposes. Clinical validation is required before any diagnostic application.

  3. Data Compatibility: Models expect preprocessed data matching our pipeline. See preprocessing documentation.

  4. Complete Results: For all 5 folds and comprehensive evaluation, see our GitHub repository and paper.

  5. Storage Considerations: Full 5-fold model collection (38GB) is available upon request. These representative Fold 4 models (6GB) are sufficient for most use cases.


📜 License

Models: CC-BY-4.0 (same as dataset)
Code: MIT License (see GitHub)

You are free to use, modify, and distribute these models with appropriate attribution.


πŸ™ Acknowledgments

Data acquired at Golgasht Medical Imaging Center, Tabriz, Iran. Ethics approval: Tabriz University of Medical Sciences (IR.TBZMED.REC.1402.902).


Made by the MS3SEG Team

GitHub • Paper • Dataset
