Spaces:
Runtime error
Runtime error
| # Code taken and adapted from https://github.com/wagnermoritz/GSE | |
| import torch | |
| from vlm_eval.attacks.attack import Attack | |
| import math | |
| class IHT(Attack): | |
| def __init__(self, model, targeted=False, img_range=(0, 1), steps=100, prox='hard',ver=False, lam=5e-5, mask_out='none',stepsize=0.015,eps=4./255.): | |
| super().__init__(model, targeted=targeted, img_range=img_range) | |
| self.steps = steps | |
| self.stepsize = stepsize | |
| self.ver = ver | |
| self.lam = lam | |
| self.eps = eps | |
| if mask_out != 'none': | |
| self.mask_out = mask_out | |
| else: | |
| self.mask_out = None | |
| if prox == 'hard': | |
| self.Prox = self.hardprox | |
| else: | |
| raise NotImplementedError | |
| def _set_mask(self, data): | |
| mask = torch.ones_like(data) | |
| if self.mask_out == 'context': | |
| mask[:, :-1, ...] = 0 | |
| elif self.mask_out == 'query': | |
| mask[:, -1, ...] = 0 | |
| elif isinstance(self.mask_out, int): | |
| mask[:, self.mask_out, ...] = 0 | |
| elif self.mask_out is None: | |
| pass | |
| else: | |
| raise NotImplementedError(f'Unknown mask_out: {self.mask_out}') | |
| return mask | |
| def __call__(self, img): | |
| for param in self.model.model.parameters(): | |
| param.requires_grad = False | |
| img = img.to(self.device) | |
| mask_out = self._set_mask(img) | |
| x = torch.zeros_like(img) # perturbation to optimize | |
| z = x.clone() # used for FISTA extrapolation | |
| t = 1 | |
| if self.ver: | |
| print('') | |
| for i in range(self.steps): | |
| # compue gradient | |
| x.requires_grad = True | |
| loss = self.model(img + x).sum() if self.targeted else -self.model(img + x).sum() | |
| loss.backward() | |
| x_grad = x.grad.data * mask_out | |
| x = x.detach() | |
| if self.ver and i % 20 == 0: | |
| print(f'Iteration: {i+1}, Loss: {loss}\n', end='') | |
| # FISTA update | |
| with torch.no_grad(): | |
| t_ = .5 * (1 + math.sqrt(1 + 4 * t ** 2)) | |
| alpha = (t - 1) / t_ | |
| t = t_ | |
| z_ = self.Prox(x=x - self.stepsize * x_grad, | |
| lam=self.lam * self.stepsize, | |
| img=img, | |
| eps=self.eps | |
| ) | |
| x = z_ + alpha * (z_ - z) | |
| x = torch.clamp(x,-self.eps,self.eps) | |
| z = z_.clone() | |
| x = torch.clamp(img + x, *self.img_range) - img | |
| if self.ver: | |
| print('') | |
| print(f"L0 pert norm: {x.norm(p=0)}") | |
| return (img + x * mask_out).detach(), x.norm(p=0).item() | |
| def hardprox(self, x, lam, img, eps): | |
| ''' | |
| Computes the hard thresholding proximal operator of the the | |
| perturbation x. | |
| :x: Perturbation after gradient descent step. | |
| :lam: Regularization parameter. | |
| ''' | |
| x_proj = torch.clamp(x,-eps,eps) | |
| x_temp = torch.clamp(img + x_proj,*self.img_range) | |
| x_proj = x_temp - img | |
| return torch.where(x ** 2 - (x_proj - x) ** 2 > 2 * lam, x_proj, 0) | |