| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | import torch |
| | import torch.nn as nn |
| | from torch.optim import Adam |
| | from tqdm import tqdm |
| | import torch.nn.functional as F |
| |
|
| | |
| | from tensorboardX import SummaryWriter |
| | import numpy as np |
| |
|
| | |
| | |
| | from model import * |
| | import lovasz_losses as L |
| |
|
| | |
| | |
| | import sys |
| | import os |
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
model_dir = './model/s3_net_model.pth'  # default checkpoint path (main() takes the real one from argv)
NUM_ARGS = 3        # expected CLI arguments: model path, train dir, dev dir
NUM_EPOCHS = 20000  # total number of training epochs
BATCH_SIZE = 1024
# Key names for the Adam keyword-argument dict assembled in main().
LEARNING_RATE = "lr"
BETAS = "betas"
EPS = "eps"
WEIGHT_DECAY = "weight_decay"

# Network channel sizes and the KL weight of the total loss
# (loss = ce + BETA * kl + lovasz).
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = 10
BETA = 0.01

# Seed all RNGs for reproducibility.  NOTE(review): set_seed and SEED1 are
# presumably provided by the star import from `model` — confirm.
set_seed(SEED1)
| |
|
| | |
| | |
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every parameter group from a fixed schedule.

    Schedule: 1e-4 up to epoch 50000, then 2e-5; past epoch 480000 the rate
    additionally decays by a factor of 0.1 for every 110000 elapsed epochs.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        epoch: current (1-based) epoch number.
    """
    if epoch > 480000:
        # Deep into training: staged exponential decay on top of the 2e-5 base.
        new_lr = 2e-5 * (0.1 ** (epoch // 110000))
    elif epoch > 50000:
        new_lr = 2e-5
    else:
        new_lr = 1e-4

    for group in optimizer.param_groups:
        group['lr'] = new_lr
| |
|
| |
|
| | |
def train(model, dataloader, dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs):
    """Run one training epoch and return the epoch-average losses.

    Args:
        model: network returning ``(semantic_scan, logits, kl_loss)``.
        dataloader: yields batches with 'scan', 'intensity',
            'angle_incidence' and 'label' tensors.
        dataset: underlying dataset (used only for the progress-bar total).
        device: torch.device all batch tensors are moved to.
        optimizer: optimizer stepped once per batch.
        ce_criterion: cross-entropy loss (reduction='sum').
        lovasz_criterion: returns a per-class Lovasz loss vector.
        class_weights: per-class weights applied to the Lovasz vector.
        epoch, epochs: current / total epoch numbers (for logging only).

    Returns:
        (train_loss, train_kl_loss, train_ce_loss) epoch averages.
    """
    model.train()

    running_loss = 0.0  # total loss accumulated over the epoch
    kl_avg_loss = 0.0   # KL term accumulated over the epoch
    ce_avg_loss = 0.0   # cross-entropy term accumulated over the epoch

    counter = 0
    num_batches = int(len(dataset) / dataloader.batch_size)
    for i, batch in tqdm(enumerate(dataloader), total=num_batches):
        counter += 1

        # Move the whole batch onto the target device.
        scans = batch['scan'].to(device)
        intensities = batch['intensity'].to(device)
        angle_incidence = batch['angle_incidence'].to(device)
        labels = batch['label'].to(device)

        batch_size = scans.size(0)

        optimizer.zero_grad()

        # Forward pass: reconstructed scan, per-class logits and the
        # KL-divergence term of the VAE objective.
        semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

        # Cross-entropy is summed by the criterion; normalize per sample.
        ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
        lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
        # BUG FIX: was hard-coded ``.to("cuda")``, which crashes CPU-only
        # runs; use the device the batch already lives on.
        lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

        loss = ce_loss + BETA * kl_loss + lovasz_loss

        # Under DataParallel the loss may be a per-GPU vector; backprop
        # through a ones tensor (equivalent to summing the components).
        loss.backward(torch.ones_like(loss))
        optimizer.step()

        # Reduce per-GPU vectors to scalars for logging.
        if torch.cuda.device_count() > 1:
            loss = loss.mean()
            ce_loss = ce_loss.mean()
            # BUG FIX: previously assigned lovasz_loss.mean() to kl_loss,
            # clobbering the KL term with the Lovasz loss.
            kl_loss = kl_loss.mean()
            lovasz_loss = lovasz_loss.mean()

        running_loss += loss.item()
        # BUG FIX: previously accumulated lovasz_loss under the KL name, so
        # the returned "kl" average was really the Lovasz average.
        kl_avg_loss += kl_loss.item()
        ce_avg_loss += ce_loss.item()

        if i % 512 == 0:
            print('Epoch [{}/{}], Step[{}/{}], Loss: {:.4f}, CE_Loss: {:.4f}, Lovasz_Loss: {:.4f}'
                  .format(epoch, epochs, i + 1, num_batches, loss.item(), ce_loss.item(), lovasz_loss.item()))

    train_loss = running_loss / counter
    train_kl_loss = kl_avg_loss / counter
    train_ce_loss = ce_avg_loss / counter

    return train_loss, train_kl_loss, train_ce_loss
| |
|
| | |
def validate(model, dataloader, dataset, device, ce_criterion, lovasz_criterion, class_weights):
    """Evaluate the model on a held-out set and return the average losses.

    Mirrors train() without the backward pass: gradients are disabled and
    the optimizer is never stepped.

    Args:
        model: network returning ``(semantic_scan, logits, kl_loss)``.
        dataloader: yields batches with 'scan', 'intensity',
            'angle_incidence' and 'label' tensors.
        dataset: underlying dataset (used only for the progress-bar total).
        device: torch.device all batch tensors are moved to.
        ce_criterion: cross-entropy loss (reduction='sum').
        lovasz_criterion: returns a per-class Lovasz loss vector.
        class_weights: per-class weights applied to the Lovasz vector.

    Returns:
        (val_loss, val_kl_loss, val_ce_loss) epoch averages.
    """
    model.eval()

    running_loss = 0.0  # total loss accumulated over the epoch
    kl_avg_loss = 0.0   # KL term accumulated over the epoch
    ce_avg_loss = 0.0   # cross-entropy term accumulated over the epoch

    counter = 0
    num_batches = int(len(dataset) / dataloader.batch_size)
    with torch.no_grad():
        for i, batch in tqdm(enumerate(dataloader), total=num_batches):
            counter += 1

            # Move the whole batch onto the target device.
            scans = batch['scan'].to(device)
            intensities = batch['intensity'].to(device)
            angle_incidence = batch['angle_incidence'].to(device)
            labels = batch['label'].to(device)

            batch_size = scans.size(0)

            semantic_scan, semantic_channels, kl_loss = model(scans, intensities, angle_incidence)

            # Cross-entropy is summed by the criterion; normalize per sample.
            ce_loss = ce_criterion(semantic_channels, labels.to(torch.long)).div(batch_size)
            lovasz_loss, _ = lovasz_criterion(semantic_channels, labels.to(torch.long))
            # BUG FIX: was hard-coded ``.to("cuda")``, which crashes CPU-only
            # runs; use the device the batch already lives on.
            lovasz_loss = lovasz_loss.mul(class_weights.to(device)).sum()

            loss = ce_loss + BETA * kl_loss + lovasz_loss

            # Reduce per-GPU vectors to scalars for logging.
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
                ce_loss = ce_loss.mean()
                # BUG FIX: previously assigned lovasz_loss.mean() to kl_loss,
                # clobbering the KL term with the Lovasz loss.
                kl_loss = kl_loss.mean()
                lovasz_loss = lovasz_loss.mean()

            running_loss += loss.item()
            # BUG FIX: previously accumulated lovasz_loss under the KL name.
            kl_avg_loss += kl_loss.item()
            ce_avg_loss += ce_loss.item()

    val_loss = running_loss / counter
    val_kl_loss = kl_avg_loss / counter
    val_ce_loss = ce_avg_loss / counter

    return val_loss, val_kl_loss, val_ce_loss
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
def main(argv):
    """Train S3Net end-to-end, logging to TensorBoard and checkpointing.

    Args:
        argv: CLI arguments (program name stripped):
            [0] model checkpoint path (loaded if it exists, saved at the end),
            [1] training data directory,
            [2] dev/validation data directory.

    Returns:
        True on completion.
    """
    # Validate the command line.
    if len(argv) != NUM_ARGS:
        # BUG FIX: the usage message previously listed five arguments while
        # NUM_ARGS is 3; only three are ever read below.
        print("usage: python train.py [MDL_PATH] [TRAIN_PATH] [DEV_PATH]")
        exit(-1)

    mdl_path = argv[0]
    pTrain = argv[1]
    pDev = argv[2]

    # Ensure the checkpoint directory exists.
    # BUG FIX: guard against an empty dirname (mdl_path without a directory
    # component), which would make os.makedirs('') raise.
    odir = os.path.dirname(mdl_path)
    if odir and not os.path.exists(odir):
        os.makedirs(odir)

    # Prefer the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('...Start reading data...')

    # Training set: drop_last keeps every batch at BATCH_SIZE.
    train_dataset = VaeTestDataset(pTrain, 'train')
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=4, \
                                                   shuffle=True, drop_last=True, pin_memory=True)

    # Dev/validation set.
    dev_dataset = VaeTestDataset(pDev, 'dev')
    dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, num_workers=2, \
                                                 shuffle=True, drop_last=True, pin_memory=True)

    # Per-class weights used by both the CE and Lovasz losses.
    # NOTE(review): values look like inverse-frequency weights precomputed
    # offline — confirm against the dataset statistics.
    class_weights = np.array([2.514399, 1.4917144, 0.51608694, 0.659483, 1.0900991, 1.6461798, 0.32852992, 1.5633508, 0.9236576, 0.10251398])

    class_weights = torch.Tensor(class_weights)
    print("class weights: ", class_weights)
    # BUG FIX: Tensor.to() is not in-place — the moved tensor was previously
    # discarded, leaving class_weights on the CPU.
    class_weights = class_weights.to(device)
    print('...Finish reading data...')

    # Build the network and move it to the target device.
    model = S3Net(input_channels=NUM_INPUT_CHANNELS,
                  output_channels=NUM_OUTPUT_CHANNELS)
    model.to(device)

    # Adam hyper-parameters (keys are the module-level constants).
    opt_params = { LEARNING_RATE: 0.001,
                   BETAS: (.9, 0.999),
                   EPS: 1e-08,
                   WEIGHT_DECAY: .001 }

    ce_criterion = nn.CrossEntropyLoss(reduction='sum', weight=class_weights)
    ce_criterion.to(device)
    lovasz_criterion = L.LovaszSoftmax(reduction='sum', ignore_index=0)
    lovasz_criterion.to(device)

    optimizer = Adam(model.parameters(), **opt_params)

    epochs = NUM_EPOCHS

    # Resume from an existing checkpoint when present.
    if os.path.exists(mdl_path):
        # BUG FIX: map_location lets a GPU-saved checkpoint load on CPU.
        checkpoint = torch.load(mdl_path, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Load epoch {} success'.format(start_epoch))
    else:
        start_epoch = 0
        print('No trained models, restart training')

    # Wrap in DataParallel when several GPUs are visible.
    if torch.cuda.device_count() > 1:
        print("Let's use 2 of total", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        model.to(device)

    # TensorBoard logging.
    writer = SummaryWriter('runs')

    epoch_num = 0
    for epoch in range(start_epoch + 1, epochs):
        adjust_learning_rate(optimizer, epoch)

        train_epoch_loss, train_kl_epoch_loss, train_ce_epoch_loss = train(
            model, train_dataloader, train_dataset, device, optimizer, ce_criterion, lovasz_criterion, class_weights, epoch, epochs
        )
        valid_epoch_loss, valid_kl_epoch_loss, valid_ce_epoch_loss = validate(
            model, dev_dataloader, dev_dataset, device, ce_criterion, lovasz_criterion, class_weights
        )

        writer.add_scalar('training loss',
                          train_epoch_loss,
                          epoch)
        writer.add_scalar('training kl loss',
                          train_kl_epoch_loss,
                          epoch)
        writer.add_scalar('training ce loss',
                          train_ce_epoch_loss,
                          epoch)

        writer.add_scalar('validation loss',
                          valid_epoch_loss,
                          epoch)
        writer.add_scalar('validation kl loss',
                          valid_kl_epoch_loss,
                          epoch)
        writer.add_scalar('validation ce loss',
                          valid_ce_epoch_loss,
                          epoch)

        print('Train set: Average loss: {:.4f}'.format(train_epoch_loss))
        print('Validation set: Average loss: {:.4f}'.format(valid_epoch_loss))

        # Periodic checkpoint; unwrap .module when running under DataParallel
        # so the saved state_dict keys stay compatible with a bare model.
        if epoch % 2000 == 0:
            if torch.cuda.device_count() > 1:
                state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            else:
                state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
            path = './model/model' + str(epoch) + '.pth'
            torch.save(state, path)

        epoch_num = epoch

    # Final checkpoint at the user-supplied path.
    if torch.cuda.device_count() > 1:
        state = {'model': model.module.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    else:
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch_num}
    torch.save(state, mdl_path)

    return True
| | |
| | |
| |
|
| |
|
| | |
| | |
# Script entry point: forward the CLI arguments (minus the program name).
if __name__ == '__main__':
    main(sys.argv[1:])
| | |
| | |
| |
|