| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| |
|
def dice(x, y):
    """Return the Dice overlap ``2 * sum(x * y) / (sum(x) + sum(y))``.

    Both inputs are reduced over *all* of their elements, whatever their
    dimensionality. They are presumably binary masks of the same shape
    (prediction ``x`` and reference ``y``) — TODO confirm with callers.

    Args:
        x: first array-like operand; multiplied element-wise with ``y``.
        y: second array-like operand.

    Returns:
        The Dice coefficient. Returns ``0.0`` when ``y`` sums to zero, so an
        empty reference never causes a division by zero.
    """
    # np.sum with no axis already reduces over every element, so one call suffices.
    intersect = np.sum(x * y)
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
| |
|
| |
|
class AverageMeter(object):
    """Track the latest value and a running weighted average of a metric.

    ``val`` and ``n`` passed to :meth:`update` may be scalars or NumPy arrays.
    With arrays the average is computed element-wise. Any element whose
    accumulated count is not positive reports the accumulated sum instead of
    a ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: latest metric value (scalar or array).
            n: weight, i.e. the number of samples ``val`` represents. It may be
                an array matching ``val`` to weight elements independently.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # The division is evaluated for every element before np.where selects
        # a branch. Use NumPy floating-point division under errstate so that a
        # zero count yields a silent nan/inf rather than a ZeroDivisionError
        # (Python scalars) or a RuntimeWarning (arrays). np.where then
        # substitutes ``sum`` for those entries.
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = np.true_divide(self.sum, self.count)
        self.avg = np.where(self.count > 0, ratio, self.sum)
| |
|
| |
|
| | def distributed_all_gather( |
| | tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None |
| | ): |
| | if world_size is None: |
| | world_size = torch.distributed.get_world_size() |
| | if valid_batch_size is not None: |
| | valid_batch_size = min(valid_batch_size, world_size) |
| | elif is_valid is not None: |
| | is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device) |
| | if not no_barrier: |
| | torch.distributed.barrier() |
| | tensor_list_out = [] |
| | with torch.no_grad(): |
| | if is_valid is not None: |
| | is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)] |
| | torch.distributed.all_gather(is_valid_list, is_valid) |
| | is_valid = [x.item() for x in is_valid_list] |
| | for tensor in tensor_list: |
| | gather_list = [torch.zeros_like(tensor) for _ in range(world_size)] |
| | torch.distributed.all_gather(gather_list, tensor) |
| | if valid_batch_size is not None: |
| | gather_list = gather_list[:valid_batch_size] |
| | elif is_valid is not None: |
| | gather_list = [g for g, v in zip(gather_list, is_valid_list) if v] |
| | if out_numpy: |
| | gather_list = [t.cpu().numpy() for t in gather_list] |
| | tensor_list_out.append(gather_list) |
| | return tensor_list_out |
| |
|