| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import math |
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | from monai import data, transforms |
| | from monai.data import NibabelReader |
| | from monai.transforms import MapTransform |
| |
|
| | |
class LoadNumpyd(MapTransform):
    """Dictionary transform that loads ``.npy`` data referenced by ``keys``.

    For every key, the current value is passed to ``np.load`` and replaced by
    the loaded array with its leading axis removed. ``np.squeeze(..., axis=0)``
    raises ``ValueError`` if that axis does not have length 1.
    """

    def __init__(self, keys):
        super().__init__(keys)

    def __call__(self, data):
        # Shallow copy: new values are written into a fresh dict, never the caller's.
        result = dict(data)
        for name in self.keys:
            array = np.load(result[name])
            result[name] = np.squeeze(array, axis=0)
        return result
| |
|
class Sampler(torch.utils.data.Sampler):
    """Shard dataset indices across distributed replicas.

    Rank ``r`` receives positions ``r, r + num_replicas, r + 2 * num_replicas, ...``
    of a shared index list. With ``shuffle`` that list is a permutation seeded by
    the current epoch (see :meth:`set_epoch`); otherwise it is
    ``0 .. len(dataset) - 1``. With ``make_even`` the list is padded to
    ``num_samples * num_replicas`` entries so every rank yields the same number
    of indices. ``valid_length`` is the number of real (non-padding) indices
    this rank owns.

    Args:
        dataset: any object supporting ``len()``.
        num_replicas: number of processes; defaults to the distributed world size.
        rank: rank of this process; defaults to the distributed rank.
        shuffle: permute the indices every epoch.
        make_even: pad so that all ranks yield ``num_samples`` indices.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` is omitted and
            ``torch.distributed`` is not available.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        # Real indices owned by this rank: positions rank, rank + R, ... below len(dataset).
        self.valid_length = len(range(self.rank, len(self.dataset), self.num_replicas))
        if not self.make_even:
            # Without padding this rank yields exactly its real indices, so
            # __len__ must report that from the start, not only after __iter__.
            self.num_samples = self.valid_length

    def __iter__(self):
        n = len(self.dataset)
        # One epoch-seeded generator drives both shuffling and padding, so all
        # ranks build the same global index list and runs are reproducible.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(n, generator=g).tolist()
        else:
            indices = list(range(n))
        if self.make_even:
            shortfall = self.total_size - n
            if 0 < shortfall < n:
                # Enough real indices: wrap around and repeat the head of the list.
                indices += indices[:shortfall]
            elif shortfall > 0:
                # More padding needed than there are indices: draw repeats at random.
                extra = torch.randint(0, n, (shortfall,), generator=g).tolist()
                indices += [indices[i] for i in extra]
            assert len(indices) == self.total_size
        indices = indices[self.rank : self.total_size : self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed shuffling and padding on the next iteration."""
        self.epoch = epoch
| |
|
| |
|
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a JSON datalist and split its entries into training and validation sets.

    The JSON file must contain a list of dicts under ``key``. Every non-empty
    string value in an entry, and every non-empty string item of a list value,
    is joined onto ``basedir``. Empty strings are left unchanged. Entries whose
    ``"fold"`` field equals ``fold`` go to the validation list; all others go
    to the training list.

    Args:
        datalist: path of the JSON datalist file (read as UTF-8).
        basedir: directory prepended to every non-empty path string.
        fold: fold index held out for validation.
        key: top-level JSON key holding the list of entries.

    Returns:
        ``(train, val)``: lists of entry dicts in file order.
    """

    def _resolve(path):
        # Keep "" as-is; os.path.join(basedir, "") would yield "basedir/".
        return path if path == "" else os.path.join(basedir, path)

    with open(datalist, encoding="utf-8") as f:
        entries = json.load(f)[key]

    for entry in entries:
        # Only existing keys are reassigned, so mutating during .items() is safe.
        for k, v in entry.items():
            if isinstance(v, list):
                entry[k] = [_resolve(item) for item in v]
            elif isinstance(v, str):
                entry[k] = _resolve(v)

    train, val = [], []
    for entry in entries:
        if "fold" in entry and entry["fold"] == fold:
            val.append(entry)
        else:
            train.append(entry)
    return train, val
| |
|
| |
|
def _build_transforms(args, augment=False):
    """Build the preprocessing pipeline shared by every split.

    Loads image/label with ``NibabelReader`` and the text feature with
    ``LoadNumpyd``. It then converts the label to BraTS multi-channel form,
    resizes image and label to ``(args.roi_x, args.roi_y, args.roi_z)``,
    normalizes image intensity, and converts everything to tensors. When
    ``augment`` is true, random intensity scaling and shifting are inserted
    before the tensor conversion.
    """
    steps = [
        transforms.LoadImaged(keys=["image", "label"], reader=NibabelReader()),
        LoadNumpyd(keys=["text_feature"]),
        transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        # The image keeps Resized's default "area" interpolation. Labels use
        # nearest-neighbour so the resized masks stay binary instead of picking
        # up fractional values along region boundaries.
        transforms.Resized(
            keys=["image", "label"],
            spatial_size=[args.roi_x, args.roi_y, args.roi_z],
            mode=("area", "nearest"),
        ),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
    if augment:
        steps.extend(
            [
                transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
                transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
            ]
        )
    steps.append(transforms.ToTensord(keys=["image", "label", "text_feature"]))
    return transforms.Compose(steps)


def _eval_loader(files, transform, args):
    """Return an unshuffled batch-size-1 DataLoader over *files*, sharded when ``args.distributed``."""
    ds = data.Dataset(data=files, transform=transform)
    sampler = Sampler(ds, shuffle=False) if args.distributed else None
    return data.DataLoader(
        ds, batch_size=1, shuffle=False, num_workers=args.workers, sampler=sampler, pin_memory=True
    )


def get_loader(args):
    """Create the data loaders for training/validation or for testing.

    Reads ``args.json_list`` (entries resolved against ``args.data_dir``) and
    holds out the entries whose fold equals ``args.fold``.

    Reads these attributes of ``args``: ``data_dir``, ``json_list``, ``fold``,
    ``roi_x``, ``roi_y``, ``roi_z``, ``test_mode``, ``distributed``,
    ``workers`` and ``batch_size`` (training only).

    Returns:
        In test mode, a single DataLoader over the held-out entries.
        Otherwise, ``[train_loader, val_loader]``. The training loader applies
        random intensity augmentation; the validation loader does not.
    """
    train_files, validation_files = datafold_read(
        datalist=args.json_list, basedir=args.data_dir, fold=args.fold
    )

    if args.test_mode:
        return _eval_loader(validation_files, _build_transforms(args), args)

    train_ds = data.Dataset(data=train_files, transform=_build_transforms(args, augment=True))
    train_sampler = Sampler(train_ds) if args.distributed else None
    train_loader = data.DataLoader(
        train_ds,
        batch_size=args.batch_size,
        # When a Sampler is present it does the (epoch-seeded) shuffling itself.
        shuffle=train_sampler is None,
        num_workers=args.workers,
        sampler=train_sampler,
        pin_memory=True,
    )
    val_loader = _eval_loader(validation_files, _build_transforms(args), args)
    return [train_loader, val_loader]
| |
|