| import os
|
| from foolbox import PyTorchModel, accuracy
|
| from foolbox.attacks.base import *
|
| from foolbox.attacks.gradient_descent_base import *
|
| from torchvision import transforms
|
| from utils.data_manager import DataManager, get_dataloader
|
| import torch
|
| import logging
|
| import eagerpy as ep
|
| from utils.data_manager import load_all_task_models
|
|
|
|
|
|
|
class SustainableAttack(Attack):
    """Attack base that evaluates samples against every saved task model.

    The constructor builds the data pipeline and normalisation settings from the
    ``args`` config dict. :meth:`to_all` / :meth:`to_alls` filter batches down to
    the samples that every loaded task model classifies correctly.
    :meth:`run_attack`, ``__call__`` and ``repeat`` are stubs here that return
    None. NOTE(review): presumably subclasses override them; confirm.
    """

    # (mean, std) channel statistics shared by the torchvision normalisation
    # and foolbox preprocessing. Any dataset other than CIFAR-100 uses the
    # second set (the standard ImageNet statistics).
    _CIFAR100_STATS = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    _DEFAULT_STATS = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def __init__(self, args, device='cuda'):
        """Build the data manager, loader, checkpoint list and normalisation.

        Args:
            args: Experiment config dict. Keys read here: ``dataset``,
                ``shuffle``, ``seed``, ``init_cls``, ``increment``, ``attack``,
                ``batch_size``, ``logs_name`` and ``target_class``. The dict is
                mutated in place: ``target_class_list`` and
                ``target_class_dict`` are written into it.
            device: Torch device on which the boolean sample masks are created.
        """
        super().__init__()
        self.device = device
        self.args = args

        self.data_manager = DataManager(
            args["dataset"],
            args["shuffle"],
            args["seed"],
            args["init_cls"],
            args["increment"],
            args["attack"],
        )
        # Classes of the first incremental task (per the data manager's class
        # order) and a class-id -> position-in-that-list lookup.
        first_task_classes = self.data_manager._class_order[:self.data_manager._increments[0]]
        self.args['target_class_list'] = first_task_classes
        self.args['target_class_dict'] = {cls: pos for pos, cls in enumerate(first_task_classes)}

        self.img_s = 32 if args["dataset"] == 'cifar100' else 224
        self.batch_size = args['batch_size']
        # NOTE(review): end_class is hard-coded to 10 rather than following the
        # first-task size (self.data_manager._increments[0]); confirm intent.
        self.loader = get_dataloader(self.data_manager, batch_size=self.batch_size,
                                     start_class=0, end_class=10,
                                     train=True, shuffle=True, num_workers=0)

        # NOTE(review): plain lexicographic sort, so e.g. "task_10.pkl" sorts
        # before "task_2.pkl" if indices are not zero-padded; verify naming.
        ckpts = sorted(name for name in os.listdir(args['logs_name']) if name.endswith('.pkl'))
        self.ckpt_paths = [os.path.join(args['logs_name'], name) for name in ckpts]
        self.model = None
        self.model0 = None
        self.attack = None

        self.target_class = args['target_class']
        stats = self._CIFAR100_STATS if args["dataset"] == "cifar100" else self._DEFAULT_STATS
        mean, std = stats
        self.norm = transforms.Normalize(mean=list(mean), std=list(std))
        # Same statistics for foolbox's PyTorchModel; axis=-3 presumably
        # selects the channel axis of channel-first (NCHW) inputs.
        self.preprocessing = dict(mean=list(mean), std=list(std), axis=-3)

    def run_attack(self):
        """Placeholder for the attack loop; does nothing here."""
        pass

    def _all_true_mask(self, n):
        """Return an eagerpy bool tensor of ``n`` True values on ``self.device``."""
        return ep.astensor(torch.ones((n,), dtype=torch.bool, device=self.device))

    def _load_task_models(self):
        """Return element 0 of ``load_all_task_models``: the task models to check."""
        return load_all_task_models(self.args, self.args['logs_name'], self.data_manager,
                                    batch_size=self.batch_size,
                                    train=True,
                                    load_type='model')[0]

    def _fmodel(self, task_model):
        """Wrap a task model's ``_network`` as a foolbox model on [0, 1] inputs."""
        return PyTorchModel(task_model._network, bounds=(0, 1), preprocessing=self.preprocessing)

    def to_alls(self, imgs, labels, labels_t=None,
                target_imgs=None, target_labels=None, return_index=False):
        """Keep only source and target samples that every task model gets right.

        Args:
            imgs, labels: Source batch. Presumably eagerpy tensors, since they
                are indexed with eagerpy boolean masks; confirm with callers.
            labels_t: Extra per-source-sample labels. They are filtered
                alongside ``imgs`` when ``self.target_class`` is set and
                ``labels_t`` is given.
            target_imgs, target_labels: Target batch. They are required despite
                the ``None`` defaults, because ``len(target_imgs)`` is taken.
            return_index: If True, return the two boolean masks (over the
                original, unfiltered batches) instead of the filtered data.

        Returns:
            ``(correct_index, correct_index_t)`` if ``return_index`` is True.
            Otherwise it returns ``(imgs, labels, labels_t, target_imgs,
            target_labels, target_logits)``, where ``target_logits`` is the
            third item of ``accuracy`` for the first task model (presumably its
            logits), restricted to the kept target samples. A group becomes all
            ``None`` when none of its samples survive.

        Raises:
            ValueError: if no task models are loaded, since the target logits
                are taken from the first one.
        """
        correct_index = self._all_true_mask(len(imgs))
        correct_index_t = self._all_true_mask(len(target_imgs))
        models = self._load_task_models()
        if len(models) == 0:
            raise ValueError("to_alls: no task models were loaded; cannot compute target logits.")

        target_logits = None
        for task, task_model in enumerate(models):
            fmodel = self._fmodel(task_model)
            # NOTE(review): assumes the project's patched foolbox `accuracy`
            # returns (acc, per-sample correctness mask, logits); stock
            # foolbox `accuracy` returns a single float.
            acc_bool = accuracy(fmodel, imgs, labels)[1]
            result_t = accuracy(fmodel, target_imgs, target_labels)
            if task == 0:
                # Only the first task model's target logits are kept.
                target_logits = result_t[2]
            correct_index = ep.logical_and(correct_index, acc_bool)
            correct_index_t = ep.logical_and(correct_index_t, result_t[1])
        # Release our references to the task models before filtering.
        del models, task_model, fmodel

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            if self.target_class is not None and labels_t is not None:
                labels_t = labels_t[correct_index]
            logging.info("Filtering %d Correct samples for all CL models.", len(labels))
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels, labels_t = None, None, None

        if correct_index_t.any():
            target_imgs = target_imgs[correct_index_t]
            target_labels = target_labels[correct_index_t]
            target_logits = target_logits[correct_index_t]
            logging.info("Filtering %d Target samples for all CL models.", len(target_labels))
        else:
            logging.info("No valid samples found for TARGET IMGS, skipping this batch.")
            target_imgs, target_labels, target_logits = None, None, None

        if return_index:
            return correct_index, correct_index_t
        return imgs, labels, labels_t, target_imgs, target_labels, target_logits

    def to_all(self, imgs, labels, return_index=False):
        """Keep only the samples that every loaded task model classifies correctly.

        Args:
            imgs, labels: Batch to filter. Presumably eagerpy tensors, since
                they are indexed with an eagerpy boolean mask; confirm.
            return_index: If True, return the boolean mask over the original
                batch instead of the filtered data.

        Returns:
            The mask if ``return_index`` is True. Otherwise ``(imgs, labels)``
            restricted to the kept samples, or ``(None, None)`` when none are
            kept. With no task models loaded, every sample is kept.
        """
        correct_index = self._all_true_mask(len(imgs))
        models = self._load_task_models()
        for task_model in models:
            fmodel = self._fmodel(task_model)
            # Item 1 of the (project-patched) `accuracy` result is used as the
            # per-sample correctness mask.
            correct_index = ep.logical_and(correct_index, accuracy(fmodel, imgs, labels)[1])
        del models

        if correct_index.any():
            imgs = imgs[correct_index]
            labels = labels[correct_index]
            logging.info("Filtering %d Correct samples for all CL models.", len(labels))
        else:
            logging.info("No valid samples found for IMGS, skipping this batch.")
            imgs, labels = None, None

        if return_index:
            return correct_index
        return imgs, labels

    def __call__(
        self,
        model: Model,
        inputs: T,
        criterion: Any,
        *,
        epsilons: Sequence[Union[float, None]],
        **kwargs: Any,
    ) -> Tuple[List[T], List[T], T]:
        """Foolbox ``Attack.__call__`` hook; a stub here that returns None."""
        ...

    def repeat(self, times: int) -> "SustainableAttack":
        """Foolbox ``Attack.repeat`` hook; a stub here that returns None."""
        ...
|
|
|
|
|