# Inference workflow: applies the network to every image found in a directory
# and saves the predicted segmentations.

imports:
  - '$import os'
  - '$import torch'
  - '$import glob'

# dictionary keys taken from MONAI's constants
image: '$monai.utils.CommonKeys.IMAGE'
pred: '$monai.utils.CommonKeys.PRED'

# hyperparameters meant to be overridden on the command line
batch_size: 1  # number of images per batch
num_workers: 0  # number of workers used to generate batches
num_classes: 4  # number of classes in the training data the network should predict
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"

# paths used by the workflow
bundle_root: '.'  # root directory of the bundle
ckpt_path: "$@bundle_root + '/models/model.pt'"  # checkpoint to load before starting
dataset_dir: "$@bundle_root + '/test_data'"  # directory the input images are read from
output_dir: './outputs'  # directory the predicted images are stored in

# network definition; its values can come from the entries above or be
# overridden on the command line
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  out_channels: '@num_classes'
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]
  num_res_units: 2

# the instantiated network, moved to the selected device
network: '$@network_def.to(@device)'

# sorted list of every file in the input directory matching the NIfTI pattern
file_pattern: '*.nii*'
data_list: '$list(sorted(glob.glob(os.path.join(@dataset_dir, @file_pattern))))'

# one data dictionary per file, keyed by the image key
data_dicts: '$[{@image:i} for i in @data_list]'

# transforms used at inference time to load and regularise the inputs
transforms:
  - _target_: LoadImaged
    keys: '@image'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@image'
  - _target_: ScaleIntensityd
    keys: '@image'

preprocessing:
  _target_: Compose
  transforms: '$@transforms'

dataset:
  _target_: Dataset
  data: '@data_dicts'
  transform: '@preprocessing'

dataloader:
  _target_: ThreadDataLoader  # generate data asynchronously from inference
  dataset: '@dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# replace with another inferer type if the training process for your network
# was different
inferer:
  _target_: SimpleInferer

# transforms applied to the network output: softmax activation, argmax
# discretisation, then saving the prediction under the output directory
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    - _target_: AsDiscreted
      keys: '@pred'
      argmax: true
    - _target_: SaveImaged
      keys: '@pred'
      meta_keys: pred_meta_dict
      data_root_dir: '@dataset_dir'
      output_dir: '@output_dir'
      dtype: '$None'
      output_dtype: '$None'
      output_postfix: ''
      resample: false
      separate_folder: true

# inference handlers: load the checkpoint and gather statistics
handlers:
  - _target_: CheckpointLoader
    # disabled when no checkpoint file exists at ckpt_path
    _disabled_: '$not os.path.exists(@ckpt_path)'
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  - _target_: StatsHandler
    name: null  # use engine.logger as the Logger object to log to
    output_transform: '$lambda x: None'

# engine that runs inference, tying together the objects defined above
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@dataloader'
  network: '@network'
  inferer: '@inferer'
  postprocessing: '@postprocessing'
  val_handlers: '@handlers'

# expressions executed when the workflow is run
run:
  - '$@evaluator.run()'