| import sys |
| sys.path.insert(0, './pytorch-image-models-main') |
|
|
| |
| from moe import Moe,all_loss |
| |
|
|
| import os |
| os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7" |
|
|
| import torch |
| import cv2 |
| from albumentations.pytorch import ToTensorV2 |
| from albumentations import ( |
| HorizontalFlip, VerticalFlip, ShiftScaleRotate, CLAHE, RandomRotate90, |
| Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue, |
| GaussNoise, MotionBlur, MedianBlur, PiecewiseAffine, RandomResizedCrop, |
| RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, CoarseDropout, |
| ShiftScaleRotate, CenterCrop, Resize, SmallestMaxSize |
| ) |
| import time |
| |
| import torch.multiprocessing as mp |
| import torch.distributed as dist |
| from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts |
| from torch.cuda.amp import autocast, GradScaler |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import Adam, SGD, AdamW, RMSprop |
| from torch import nn |
| import random |
| from tqdm import tqdm |
| from PIL import Image |
| import numpy as np |
| import logging |
| from sklearn.model_selection import GroupKFold, StratifiedKFold |
| import pandas as pd |
| import math |
|
|
# Central experiment configuration read throughout the script.
# Key order matters only cosmetically: the whole dict is logged at start-up.
CFG = dict(
    seed=42,
    model_arch='convnext_large_mlp',
    # Tile edge length in pixels used by the 'tokenmix' augmentation.
    patch=16,
    # Per-channel statistics passed to albumentations.Normalize.
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    # Batch mixing: one of 'cutmix', 'tokenmix', 'mixup', 'randommix';
    # a training batch is mixed with probability mix_prob.
    mix_type='cutmix',
    mix_prob=0.7,
    img_size=512,
    class_num=1784,
    # LinearLR warm-up (starting at warmup_lr_factor * lr) for warmup_epochs,
    # then CosineAnnealingLR down to min_lr.
    warmup_epochs=1,
    warmup_lr_factor=0.01,
    epochs=11,
    train_bs=24,
    valid_bs=64,
    lr=7.5e-5,
    min_lr=1e-5,
    # differLR=True builds explicit backbone / non-backbone parameter groups;
    # head_lr (when > 0) and head_wd apply to the non-backbone group,
    # otherwise only the backbone is optimised.
    differLR=False,
    head_lr=0,
    head_wd=0.05,
    num_workers=8,
    device='cuda',
    # label_smoothing for the validation CrossEntropyLoss.
    smoothing=0.1,
    weight_decay=2e-5,
    accum_iter=1,
    verbose_step=1,
)
|
|
# Module-level logger that writes INFO records to
# logs/<model_arch>_train_moe.log.
# logging.FileHandler does not create missing directories (it raises
# FileNotFoundError), so make sure 'logs/' exists before opening the file.
os.makedirs('logs', exist_ok=True)
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
handler = logging.FileHandler(f"logs/{CFG['model_arch']}_train_moe.log")
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
|
|
|
|
def seed_everything(seed, deterministic=True):
    """Seed every random number generator the script uses and configure cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's global RNG,
            and torch (CPU and every visible CUDA device). It is also exported
            as ``PYTHONHASHSEED``, which only affects interpreters started
            afterwards (e.g. spawned workers), not the running one.
        deterministic: When True (default), request deterministic cuDNN kernels
            and disable the cuDNN benchmark auto-tuner. Pass False to trade
            reproducibility for speed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # torch.cuda.manual_seed only seeds the *current* GPU. The model is wrapped
    # in DataParallel over several GPUs, so seed all of them (silently ignored
    # when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    # benchmark=True makes cuDNN time several algorithms and keep the fastest,
    # which may be non-deterministic. That defeats deterministic=True, so the
    # two flags must be opposite.
    torch.backends.cudnn.benchmark = not deterministic
|
|
|
|
def get_img(path):
    """Read the image at ``path`` and return it with RGB channel order.

    Args:
        path: Filesystem path handed to ``cv2.imread`` (default colour mode).

    Returns:
        A C-contiguous ``H x W x 3`` NumPy array in RGB order.

    Raises:
        FileNotFoundError: ``cv2.imread`` could not read/decode the file.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        # Without this check the next line fails with an unhelpful
        # "'NoneType' object is not subscriptable".
        raise FileNotFoundError(f'Could not read image: {path}')
    # OpenCV decodes as BGR; reverse the channel axis to get RGB. The reversed
    # slice is a negative-stride view, which torch.as_tensor / DataLoader
    # collation reject (e.g. when no transforms are applied), so make it
    # contiguous.
    return np.ascontiguousarray(im_bgr[:, :, ::-1])
|
|
|
|
# --- Data locations and metadata ---------------------------------------------
# Image roots are plain string prefixes; FGVCDataset concatenates them with the
# CSV 'image_path' column.
train_data_root = '/data1/dataset/SnakeCLEF2024/'
val_data_root = '/data1/dataset/SnakeCLEF2023/val/SnakeCLEF2023-large_size/'
train_df = pd.read_csv('./metadata/train_full.csv')
valid_df = pd.read_csv('./metadata/SnakeCLEF2023-ValMetadata.csv')
is_venomous_df = pd.read_csv('./metadata/venomous_status_list.csv')

# Per-class venomous flag ('MIVS' column), kept in two forms:
#   * venomous_mask: dense tensor over all class_num classes. Classes missing
#     from the CSV keep the default 1; a repeated class_id keeps its LAST value.
#   * class_id2venomous: dict lookup; a repeated class_id keeps its FIRST value.
# NOTE(review): the two only disagree if the CSV repeats a class_id with
# different flags — confirm the ids are unique.
venomous_mask = torch.ones(CFG['class_num'])
class_id2venomous = {}
for class_id, is_venomous in zip(is_venomous_df['class_id'], is_venomous_df['MIVS']):
    venomous_mask[class_id] = is_venomous
    class_id2venomous.setdefault(class_id, is_venomous)

# Attach the per-class flag to every image row (Series.map with a dict yields
# NaN for class ids that are not in the mapping).
train_df['MIVS'] = train_df['class_id'].map(class_id2venomous)
valid_df['MIVS'] = valid_df['class_id'].map(class_id2venomous)
|
|
class FGVCDataset(Dataset):
    """Image dataset backed by a metadata DataFrame.

    Each row provides an ``image_path`` (relative to ``data_root``). When labels
    are requested, it also provides a ``class_id`` and a ``MIVS`` venomous flag.

    Args:
        df: Metadata frame. It is copied and its index discarded, so rows are
            addressed by position ``0 .. len(df) - 1``.
        data_root: String prefix concatenated (not path-joined) with each
            ``image_path``.
        transforms: Optional callable invoked as ``transforms(image=img)['image']``.
        output_label: When truthy, ``__getitem__`` returns
            ``(img, class_id, MIVS)``; otherwise it returns the image only.
        one_hot_label: When truthy (and labels are output), ``labels`` become
            one-hot rows of width ``max(class_id) + 1`` of *this* frame.
    """

    def __init__(self, df, data_root,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False
                 ):
        super().__init__()
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root

        self.output_label = output_label
        self.one_hot_label = one_hot_label

        if output_label:
            self.labels = self.df['class_id'].values
            # Positional NumPy array like ``labels``, so lookups never depend
            # on the pandas index.
            self.is_venomous = self.df['MIVS'].values
            if one_hot_label:
                # NOTE(review): the width comes from this frame's max class_id,
                # so different splits may get different widths, and a class_id
                # of -1 would select the last row. Confirm before relying on it.
                self.labels = np.eye(self.df['class_id'].max() + 1)[self.labels]

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index: int):
        # ``.at`` is a scalar lookup. ``df.loc[index][col]`` materialised a
        # whole row Series for every sample.
        image_path = self.data_root + self.df.at[index, 'image_path']
        img = get_img(image_path)

        if self.transforms:
            img = self.transforms(image=img)['image']

        if self.output_label:
            return img, self.labels[index], self.is_venomous[index]
        return img
|
|
|
|
def get_train_transforms():
    """Build the albumentations training pipeline.

    Stages, in order:
        1. Random resized crop to a ``CFG['img_size']`` square (bicubic).
        2. Random transpose / flips, shift-scale-rotate, and piecewise affine.
        3. Brightness/contrast jitter (always applied).
        4. With probability 0.5, one of optical or grid distortion.
        5. Normalisation with ``CFG['mean']`` / ``CFG['std']``, then ToTensorV2.
    """
    side = CFG['img_size']

    geometric = [
        RandomResizedCrop(side, side, interpolation=cv2.INTER_CUBIC, scale=(0.5, 1.3)),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.3),
        PiecewiseAffine(p=0.5),
    ]
    photometric = [
        RandomBrightnessContrast(brightness_limit=(-0.2, 0.2),
                                 contrast_limit=(-0.2, 0.2),
                                 p=1.0),
    ]
    distortion = OneOf(
        [
            OpticalDistortion(distort_limit=1.0),
            GridDistortion(num_steps=5, distort_limit=1.),
        ],
        p=0.5,
    )
    to_model_input = [
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]

    return Compose(geometric + photometric + [distortion] + to_model_input, p=1.)
|
|
|
|
|
|
def get_valid_transforms():
    """Build the deterministic evaluation pipeline.

    Resize to a ``CFG['img_size']`` square (bicubic), normalise with
    ``CFG['mean']`` / ``CFG['std']``, then convert with ToTensorV2.
    """
    side = CFG['img_size']
    steps = [
        Resize(side, side, interpolation=cv2.INTER_CUBIC),
        Normalize(mean=CFG['mean'], std=CFG['std'], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
|
|
|
|
def prepare_dataloader(train_df, val_df, train_idx, val_idx):
    """Build the training and validation DataLoaders.

    Args:
        train_df: Training metadata frame; images are read from ``train_data_root``.
        val_df: Validation metadata frame; images are read from ``val_data_root``.
        train_idx: Index labels (used with ``.loc``) selecting training rows.
        val_idx: Index labels (used with ``.loc``) selecting validation rows.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and uses
        the training augmentations; the validation loader keeps order and uses
        the evaluation transforms.
    """
    train_rows = train_df.loc[train_idx, :].reset_index(drop=True)
    val_rows = val_df.loc[val_idx, :].reset_index(drop=True)

    train_set = FGVCDataset(train_rows, train_data_root, transforms=get_train_transforms())
    val_set = FGVCDataset(val_rows, val_data_root, transforms=get_valid_transforms())

    common = dict(num_workers=CFG['num_workers'], pin_memory=False)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=CFG['train_bs'],
        shuffle=True,
        drop_last=False,
        **common,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=CFG['valid_bs'],
        shuffle=False,
        **common,
    )
    return train_loader, val_loader
|
|
def rand_bbox(size, lam):
    """Sample a random CutMix box for a tensor of shape ``size``.

    Along dimensions 2 and 3 the box spans ``int(extent * sqrt(1 - lam))``
    (rounded down to an even length). It is centred at a uniformly random
    position and clipped to the tensor bounds, so the realised box can be
    smaller.

    Args:
        size: Shape sequence; only ``size[2]`` and ``size[3]`` are read.
        lam: Mixing ratio in ``[0, 1]``. The unclipped box covers roughly
            ``1 - lam`` of the area.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: ``[bbx1, bbx2)`` indexes dimension 2 and
        ``[bby1, bby2)`` indexes dimension 3.
    """
    dim_x, dim_y = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_x = np.int32(dim_x * side_ratio) // 2
    half_y = np.int32(dim_y * side_ratio) // 2

    # Draw the centre: x first, then y (keeps the NumPy RNG stream unchanged).
    centre_x = np.random.randint(dim_x)
    centre_y = np.random.randint(dim_y)

    x_lo = np.clip(centre_x - half_x, 0, dim_x)
    y_lo = np.clip(centre_y - half_y, 0, dim_y)
    x_hi = np.clip(centre_x + half_x, 0, dim_x)
    y_hi = np.clip(centre_y + half_y, 0, dim_y)
    return x_lo, y_lo, x_hi, y_hi
|
|
|
|
def generate_mask_random(imgs, patch=CFG['patch'], mask_token_num_start=14, lam=0.5):
    """Choose random tiles to replace for TokenMix.

    The image is treated as a square ``p x p`` grid of ``patch``-pixel tiles,
    with ``p = imgs.shape[2] // patch``. Dimension 3 is only checked for
    divisibility; it does not affect the grid size. The number of tiles chosen
    is ``int((1 - lam) * p*p) + mask_token_num_start``, capped at ``p*p``.

    Args:
        imgs: Tensor/array with a 4-element ``shape``.
        patch: Tile edge length in pixels.
        mask_token_num_start: Extra tiles added on top of the ``lam``-based count.
        lam: Requested fraction of tiles to keep.

    Returns:
        ``(mask_idx, lam)``: distinct flat tile indices in ``[0, p*p)`` and the
        fraction of tiles actually kept.
    """
    _, _, W, H = imgs.shape
    assert W % patch == 0
    assert H % patch == 0
    n_tiles = (W // patch) ** 2
    n_masked = min(n_tiles, int((1 - lam) * n_tiles) + mask_token_num_start)
    chosen = np.random.permutation(n_tiles)[:n_masked]
    kept_ratio = 1 - n_masked / n_tiles
    return chosen, kept_ratio
|
|
|
|
def get_mixed_data(imgs, image_labels, is_venomous, mix_type):
    """Apply one batch-level mixing augmentation.

    A random permutation ``rand_index`` pairs every sample with a partner from
    the same batch, and partner pixels are blended in according to ``mix_type``.

    Args:
        imgs: Batch tensor indexed as ``imgs[n, c, a, b]``. It is modified in
            place for 'cutmix' and 'tokenmix'.
        image_labels: Per-sample labels, indexed with ``rand_index``.
        is_venomous: Per-sample venomous flags, permuted together with the labels.
        mix_type: 'cutmix', 'tokenmix', 'mixup', or 'randommix' (draws one of
            'cutmix' / 'tokenmix').

    Returns:
        ``(imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam)``: the
        mixed batch, the original and partner labels/flags, and ``lam``. ``lam``
        is the weight of the original sample; the partner gets ``1 - lam``.

    Raises:
        AssertionError: ``mix_type`` is not supported.
    """
    mix_lst = ['cutmix', 'tokenmix', 'mixup', 'randommix']
    assert mix_type in mix_lst, f'Not Supported mix type: {mix_type}'
    if mix_type == 'randommix':
        # NOTE(review): mix_lst[:-2] == ['cutmix', 'tokenmix'], so 'mixup' is
        # never drawn here. Confirm this exclusion is intended.
        mix_type = random.choice(mix_lst[:-2])

    # One partner permutation shared by every mix type. It is still drawn from
    # torch's CPU generator (as with the former ``.cuda()``), but is moved to
    # the batch's own device so CPU batches work too.
    rand_index = torch.randperm(imgs.size()[0]).to(imgs.device)
    target_a = image_labels
    target_b = image_labels[rand_index]
    # These used to be bound only in the 'cutmix' branch, so 'mixup' and
    # 'tokenmix' crashed with UnboundLocalError at the return statement.
    is_venomous_a = is_venomous
    is_venomous_b = is_venomous[rand_index]

    if mix_type == 'mixup':
        alpha = 2.0
        lam = np.random.beta(alpha, alpha)
        imgs = imgs * lam + imgs[rand_index] * (1 - lam)
    elif mix_type == 'cutmix':
        beta = 1.0
        lam = np.random.beta(beta, beta)
        bbx1, bby1, bbx2, bby2 = rand_bbox(imgs.size(), lam)
        imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[rand_index, :, bbx1:bbx2, bby1:bby2]
        # The box may be clipped at the border; recompute lam from the pasted
        # area so the label weights match the pixels actually mixed.
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (imgs.size()[-1] * imgs.size()[-2]))
    else:  # 'tokenmix'
        patch = CFG['patch']
        mask_idx, lam = generate_mask_random(imgs)
        # Tiles form a p x p grid over dims 2 and 3; flat indices are decoded
        # row-major (row along dim 2, column along dim 3).
        p = imgs.shape[2] // patch
        for idx in mask_idx:
            x1 = patch * (idx // p)
            y1 = patch * (idx % p)
            imgs[:, :, x1:x1 + patch, y1:y1 + patch] = imgs[rand_index, :, x1:x1 + patch, y1:y1 + patch]

    return imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam
|
|
|
|
def train_one_epoch_mix(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False, mix_type=CFG['mix_type'], scaler=None):
    """Train ``model`` for one epoch with AMP autocast and batch mixing.

    With probability ``CFG['mix_prob']`` a batch is mixed by ``get_mixed_data``.
    Its loss is then ``lam * loss(target_a) + (1 - lam) * loss(target_b)``.
    Gradients are accumulated over ``CFG['accum_iter']`` steps (and on the last
    step) before each optimizer step.

    Args:
        epoch: Epoch number, used in progress and log messages.
        model: Callable returning ``(y_hat, expert_pred, alpha, image_preds)``.
            ``image_preds`` drives the accuracy metric.
        loss_fn: Called as
            ``loss_fn(y_hat, expert_pred, alpha, image_preds, labels, is_venomous)``.
        optimizer: Stepped through the GradScaler.
        train_loader: Yields ``(imgs, labels, is_venomous)`` batches.
        device: Device the batch tensors are moved to.
        scheduler: Optional LR scheduler. It is stepped after every optimizer
            step when ``schd_batch_update`` is True, otherwise once at epoch end.
        schd_batch_update: See ``scheduler``.
        mix_type: Forwarded to ``get_mixed_data``.
        scaler: GradScaler used for backward/step. When omitted, falls back to
            the module-level ``scaler`` that the ``__main__`` block creates
            (the original behaviour).

    Raises:
        RuntimeError: no ``scaler`` was passed and no module-level one exists.
    """
    if scaler is None:
        # Backward compatibility: the script's __main__ block defines a global
        # ``scaler`` and calls this function without passing one.
        scaler = globals().get('scaler')
        if scaler is None:
            raise RuntimeError('train_one_epoch_mix needs a GradScaler: pass scaler=...')

    model.train()
    accum_iter = CFG['accum_iter']
    num_steps = len(train_loader)

    running_loss = None
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(train_loader), total=num_steps, ncols=70)
    for step, (imgs, image_labels, is_venomous) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()
        is_venomous = is_venomous.to(device).float()

        if np.random.rand(1) < CFG['mix_prob']:
            imgs, target_a, target_b, is_venomous_a, is_venomous_b, lam = get_mixed_data(
                imgs, image_labels, is_venomous, mix_type)
            with autocast():
                y_hat, expert_pred, alpha, image_preds = model(imgs)
                loss = (loss_fn(y_hat, expert_pred, alpha, image_preds, target_a, is_venomous_a) * lam
                        + loss_fn(y_hat, expert_pred, alpha, image_preds, target_b, is_venomous_b) * (1.0 - lam))
        else:
            with autocast():
                y_hat, expert_pred, alpha, image_preds = model(imgs)
                loss = loss_fn(y_hat, expert_pred, alpha, image_preds, image_labels, is_venomous)

        # Average the gradient over the accumulation window. Without this, a
        # window of k steps multiplied the effective update size by k.
        # No-op when accum_iter == 1.
        scaler.scale(loss / accum_iter).backward()

        # Accuracy is measured against the un-permuted labels, including on
        # mixed batches.
        image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
        image_targets_all.append(image_labels.detach().cpu().numpy())

        # Exponential moving average (decay 0.99) of the unscaled batch loss.
        loss_value = loss.item()
        running_loss = loss_value if running_loss is None else running_loss * .99 + loss_value * .01

        is_last_step = (step + 1) == num_steps
        if (step + 1) % accum_iter == 0 or is_last_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if (step + 1) % CFG['verbose_step'] == 0 or is_last_step:
            description = f'epoch {epoch} loss: {running_loss:.4f}'
            pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    accuracy = (image_preds_all == image_targets_all).mean()

    print('Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' Train loss = {:.4f}'.format(running_loss))

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
|
|
|
|
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on ``val_loader`` and return top-1 accuracy.

    Accuracy compares ``argmax(image_preds)`` with the raw labels, so rows
    labelled -1 (open-set) always count as misses. For the loss only, those
    rows are remapped to class 0 so they form a valid index for ``loss_fn``.

    Args:
        epoch: Epoch number, used in progress and log messages.
        model: Callable returning a 4-tuple whose last element ``image_preds``
            holds the class scores.
        loss_fn: Called as ``loss_fn(image_preds, labels)``.
        val_loader: Yields ``(imgs, labels, is_venomous)`` batches; the
            venomous flag is not used here.
        device: Device the batch tensors are moved to.
        scheduler: Optional scheduler stepped once at the end. It receives the
            mean loss when ``schd_loss_update`` is True.
        schd_loss_update: See ``scheduler``.

    Returns:
        Fraction of validation samples predicted correctly.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=70)
    # Evaluation never needs gradients. Disabling autograd here keeps memory
    # bounded even when the caller forgets to wrap the call in torch.no_grad().
    with torch.no_grad():
        for step, (imgs, image_labels, _is_venomous) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            _, _, _, image_preds = model(imgs)
            image_preds_all.append(torch.argmax(image_preds, 1).cpu().numpy())
            image_targets_all.append(image_labels.cpu().numpy())

            # Remap open-set labels out of place. The former in-place write
            # also rewrote the targets stored above whenever the tensor lived
            # on the CPU, because .cpu().numpy() then shares memory with it.
            loss_labels = image_labels.masked_fill(image_labels == -1, 0)
            loss = loss_fn(image_preds, loss_labels)

            loss_sum += loss.item() * image_labels.shape[0]
            sample_num += image_labels.shape[0]

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
                description = f'epoch {epoch} loss: {loss_sum / sample_num:.4f}'
                pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)

    accuracy = (image_preds_all == image_targets_all).mean()
    print('validation multi-class accuracy = {:.4f}'.format(accuracy))
    logger.info(' Epoch: ' + str(epoch) + ' validation multi-class accuracy = {:.4f}'.format(accuracy))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum / sample_num)
        else:
            scheduler.step()
    return accuracy
|
|
|
|
|
|
if __name__ == '__main__':
    seed_everything(CFG['seed'])
    logger.info(CFG)

    # Train on every row of train_df and validate on every row of valid_df.
    trn_idx = np.arange(train_df.shape[0])
    val_idx = np.arange(valid_df.shape[0])

    # Per-class training image counts handed to all_loss. Without minlength
    # the vector has max(class_id) + 1 entries, which is shorter than
    # class_num whenever the highest class ids are absent from the training CSV.
    df_class_id = np.array(train_df['class_id'])
    class_counts = np.bincount(df_class_id, minlength=CFG['class_num'])
    device = torch.device(CFG['device'])

    model = Moe(CFG['model_arch'], CFG['class_num'], venomous_mask)
    model = nn.DataParallel(model)
    model.to(device)
    # Tensor.to() is not in-place: the previous bare calls discarded their
    # result. Assign it back. If the tensor is already on ``device``, .to()
    # returns the same object, so this is harmless in that case.
    model.module.not_venomous_mask = model.module.not_venomous_mask.to(device)
    model.module.venomous_mask = model.module.venomous_mask.to(device)

    train_loader, val_loader = prepare_dataloader(train_df, valid_df, trn_idx, val_idx)

    # Keep this a module-level name: train_one_epoch_mix reads the global
    # ``scaler``.
    scaler = GradScaler()

    if CFG['differLR']:
        # A set gives O(1) membership tests (the list made this O(n^2)).
        backbone_ids = {id(p) for p in model.module.backbone.parameters()}
        head_params = [p for p in model.parameters() if id(p) not in backbone_ids]

        if CFG['head_lr'] > 0:
            lr_cfg = [
                {'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']},
                {'params': head_params, 'lr': CFG['head_lr'], 'weight_decay': CFG['head_wd']},
            ]
        else:
            print('frozen center')
            model.module.center.requires_grad = False
            # Only backbone parameters are given to the optimizer, so every
            # non-backbone parameter is left un-updated in this branch.
            lr_cfg = [
                {'params': model.module.backbone.parameters(), 'lr': CFG['lr'], 'weight_decay': CFG['weight_decay']}]
        optimizer = torch.optim.AdamW(lr_cfg, lr=CFG['lr'], weight_decay=CFG['weight_decay'])
    else:
        optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])

    # Per-epoch schedule: linear warm-up for warmup_epochs, then cosine decay
    # to min_lr over the remaining epochs.
    main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=CFG['epochs'] - CFG['warmup_epochs'], eta_min=CFG['min_lr']
    )
    warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=CFG['warmup_lr_factor'], total_iters=CFG['warmup_epochs']
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[CFG['warmup_epochs']]
    )

    loss_tr = all_loss(class_counts, CFG['class_num']).to(device)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=CFG['smoothing']).to(device)

    ckpt_path = './checkpoints_moe/moe_{}_mix_{}_mixprob_{}_seed_{}_ls_{}_epochs_{}_differLR_{}_imsize{}.pth'.format(
        CFG['model_arch'],
        CFG['mix_type'],
        CFG['mix_prob'],
        CFG['seed'],
        CFG['smoothing'],
        CFG['epochs'],
        CFG['differLR'],
        CFG['img_size'])
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    best_acc = 0.0
    for epoch in range(CFG['epochs']):
        print(optimizer.param_groups[0]['lr'])
        train_one_epoch_mix(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler)
        with torch.no_grad():
            temp_acc = valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)
        # Keep only the best-so-far checkpoint (same path, overwritten).
        if temp_acc > best_acc:
            best_acc = temp_acc
            torch.save(model.state_dict(), ckpt_path)

    del model, optimizer, train_loader, val_loader, scaler, scheduler
    print(best_acc)
    logger.info('BEST-Valid-ACC: ' + str(best_acc))
    torch.cuda.empty_cache()
|
|