| from attacks.CleanSheet.packet import * |
| import torch |
| from tqdm.auto import tqdm |
| from pynvml import * |
| from utils.data_manager import DataManager |
|
|
def train(args_cl):
    """CleanSheet baseline: co-train a 4-model ensemble and a universal trigger.

    Each epoch, the model with the best clean test accuracy is elected
    "teacher"; it is trained on clean data with plain cross-entropy, while the
    other three models are trained by knowledge distillation from it.  After
    each model's clean phase, the shared trigger is optimized on poisoned data
    against that model; the per-model trigger/mask snapshots are averaged into
    a single universal trigger at the end of the epoch.

    Args:
        args_cl: dict with keys "dataset", "shuffle", "seed", "init_cls",
            "increment", "target_class", "logs_eval_name".

    Side effects: prints progress/metrics and saves the trigger state dict to
    ``{logs_eval_name}/Baseline_Trigger/{target_class}/epoch_{e}.pth``.

    Fixes vs. the previous revision:
      * updates now go through one persistent SGD optimizer per model instead
        of a fresh ``optim.SGD(...)`` per batch — fresh optimizers discarded
        the momentum buffers every step and ignored the cosine LR schedule;
      * the student trigger phase optimizes against the current ``model``
        (previously it always used ``student1``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # --- hyper-parameters -------------------------------------------------
    epochs = 100
    save_interval = 1          # save the trigger every N epochs
    temperature = 1.0          # distillation softmax temperature
    alpha = 1.0                # loss mix: alpha*soft + (1-alpha)*hard
    pr = 0.1                   # fraction of the training set that is poisoned
    best_model_index = 0       # teacher index; re-elected each epoch
    beta = 1.0                 # clamp bound for the trigger pattern
    target_class = args_cl['target_class']

    # --- data -------------------------------------------------------------
    data_manager = DataManager(
        args_cl["dataset"],
        args_cl["shuffle"],
        args_cl["seed"],
        args_cl["init_cls"],
        args_cl["increment"],
        False,
    )

    clean_train_data = data_manager.get_dataset(np.arange(0, 10), source="train", mode="train")
    print(len(clean_train_data))
    clean_train_dataloader = DataLoader(clean_train_data, batch_size=128, num_workers=0,
                                        pin_memory=True, shuffle=True)

    clean_test_data = data_manager.get_dataset(np.arange(0, 10), source="test", mode="test")
    print(len(clean_test_data))
    clean_test_dataloader = DataLoader(clean_test_data, batch_size=128, num_workers=0,
                                       pin_memory=True)

    # Poison a random pr-fraction of the training set; the whole test set is
    # poisoned so ASR is measured over every sample.
    poison_train_data = PoisonDataset(
        clean_train_data,
        np.random.choice(len(clean_train_data), int(pr * len(clean_train_data)), replace=False),
        target=target_class,
    )
    print(len(poison_train_data))
    poison_train_dataloader = DataLoader(poison_train_data, batch_size=128, num_workers=0,
                                         pin_memory=True, shuffle=True)

    poison_test_data = PoisonDataset(
        clean_test_data,
        np.random.choice(len(clean_test_data), len(clean_test_data), replace=False),
        target=target_class,
    )
    print(len(poison_test_data))
    poison_test_dataloader = DataLoader(poison_test_data, batch_size=128, num_workers=0,
                                        pin_memory=True)

    # --- models, one persistent optimizer + scheduler per model -----------
    # Index 0 (resnet34) is the initial teacher; the rest start as students.
    model_builders = [resnet34, resnet18, vgg16, mobilenet_v2]
    models = []
    optimizers = []
    schedulers = []
    for build in model_builders:
        model = build(num_classes=10)
        model.to(device)
        model.eval()
        optimizer = optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=5e-4)
        models.append(model)
        optimizers.append(optimizer)
        schedulers.append(lr_scheduler.CosineAnnealingLR(optimizer, T_max=100))

    # Trigger-loss weights (teacher phase pushes harder on the CE term).
    teacher_lambda_t = 1e-1
    teacher_lambda_mask = 1e-4
    teacher_trainable_when_training_trigger = False

    student_lambda_t = 1e-2
    student_lambda_mask = 1e-4
    student_trainable_when_training_trigger = False

    # --- trigger ----------------------------------------------------------
    tri = Trigger(size=32).to(device)
    trigger_optimizer = optim.Adam(tri.parameters(), lr=1e-2)

    def train_trigger(model, optimizer, lambda_t, lambda_mask, model_trainable):
        """One pass of trigger optimization on poisoned data against `model`.

        Returns (mask, trigger) snapshots (clones) after clamping.  The model
        is frozen unless `model_trainable` is set, in which case its persistent
        optimizer also steps.
        """
        model.eval()
        tri.train()
        for x, y in tqdm(poison_train_dataloader):
            x = tri(x.to(device))
            y = y.to(device)
            loss = lambda_t * F.cross_entropy(model(x), y) \
                + lambda_mask * torch.norm(tri.mask, p=2)
            optimizer.zero_grad()
            trigger_optimizer.zero_grad()
            loss.backward()
            trigger_optimizer.step()
            if model_trainable:
                optimizer.step()

        # Keep the mask a valid blend weight and bound the trigger amplitude.
        with torch.no_grad():
            tri.mask.clamp_(0, 1)
            tri.trigger.clamp_(-beta, beta)
        return tri.mask.clone(), tri.trigger.clone()

    def accuracy_on(model, dataloader, poisoned):
        """Top-1 accuracy of `model` on `dataloader`.

        Poisoned batches are (x, y) pairs and get the trigger applied (this
        measures attack success rate); clean batches are (idx, x, y) triples.
        """
        model.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for batch in tqdm(dataloader):
                if poisoned:
                    x, y = batch
                    x = tri(x.to(device))
                    y = y.to(device)
                else:
                    _, x, y = batch
                    x = x.to(device)
                    y = y.to(device).to(torch.int64)
                _, predict_label = model(x).max(1)
                total += y.size(0)
                correct += predict_label.eq(y).sum().item()
        return correct / total

    # --- main loop --------------------------------------------------------
    print("Start generate triggers")
    tri.train()

    for epoch in range(epochs):
        masks = []
        triggers = []
        best_model = models[best_model_index]

        print('epoch: {}'.format(epoch))
        for index, (model, optimizer) in enumerate(zip(models, optimizers)):
            if index == best_model_index:
                # Teacher phase: plain supervised training on clean data.
                print('train teacher network with clean data')
                model.train()
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    loss = F.cross_entropy(model(x), y.to(torch.int64))
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                print('train trigger for teacher network with poison data')
                mask, trigger = train_trigger(model, optimizer,
                                              teacher_lambda_t, teacher_lambda_mask,
                                              teacher_trainable_when_training_trigger)
            else:
                # Student phase: knowledge distillation from the current teacher.
                best_model.eval()
                model.train()
                print('train student network with clean data')
                for _, x, y in tqdm(clean_train_dataloader):
                    x = x.to(device)
                    y = y.to(device)
                    student_logits = model(x)
                    with torch.no_grad():
                        teacher_logits = best_model(x)
                    soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
                                         F.softmax(teacher_logits / temperature, dim=1),
                                         reduction='batchmean')
                    hard_loss = F.cross_entropy(student_logits, y.to(torch.int64))
                    loss = alpha * soft_loss + (1 - alpha) * hard_loss
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                print(' train trigger for student network with poison data')
                # Fixed: optimize the trigger against this student, not student1.
                mask, trigger = train_trigger(model, optimizer,
                                              student_lambda_t, student_lambda_mask,
                                              student_trainable_when_training_trigger)

            masks.append(mask)
            triggers.append(trigger)

        # Fuse the per-model snapshots into one universal trigger.
        tri.mask.data = torch.mean(torch.stack(masks), dim=0)
        tri.trigger.data = torch.mean(torch.stack(triggers), dim=0)

        for scheduler in schedulers:
            scheduler.step()

        # Re-elect the teacher: highest clean test accuracy wins.
        accuracies = [accuracy_on(model, clean_test_dataloader, poisoned=False)
                      for model in models]
        best_model_index = int(np.argmax(accuracies))

        print("--------Validation accuracy of 4 models(clean_test_dataloader)---------")
        print(accuracies)
        print("--------Selected as the index for the teacher model---------")
        print(best_model_index)

        # Attack success rate: accuracy w.r.t. the poisoned (target) labels.
        ASR = [accuracy_on(model, poison_test_dataloader, poisoned=True)
               for model in models]

        print("--------The attack success rate of 4 models(poison_test_dataloader)---------")
        print(ASR)

        if epoch == 0 or (epoch + 1) % save_interval == 0:
            trigger_p = '{}/Baseline_Trigger/{}/epoch_{}.pth'.format(args_cl['logs_eval_name'], target_class, epoch)
            os.makedirs(os.path.dirname(trigger_p), exist_ok=True)
            torch.save(tri.state_dict(), trigger_p)
|
|
|
|
|
|