Source code for shadow.vat

import torch
import shadow.losses
import shadow.utils
import shadow.module_wrapper
import warnings


def l2_normalize(r):
    r"""L2 normalize tensor, flattening over all but batch dim (0).

    Args:
        r (torch.Tensor): Tensor to normalize.

    Returns:
        torch.Tensor: Normalized tensor (over all but dim 0).
    """
    # Compute length over all but sample dim
    norm = torch.norm(r.view(r.shape[0], -1), dim=1, keepdim=False)
    # Divide by the length, insert single dims into the length vector for
    # array broadcasting. Add a small perturbation for stability.
    return r / (norm.view(*norm.shape, *(1 for _ in r.shape[1:])) + torch.finfo(r.dtype).tiny)

def rand_unit_sphere(x):
    r"""Draw samples from the uniform distribution over the unit sphere.

    Args:
        x (torch.Tensor): Tensor used to define shape and dtype for the generated tensor.

    Returns:
        torch.Tensor: Random unit sphere samples.

    Reference:
        https://stats.stackexchange.com/questions/7977/how-to-generate-uniformly-distributed-points-on-the-surface-of-the-3-d-unit-sphe
    """  # noqa: E501
    return l2_normalize(torch.randn_like(x))

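A minimal sketch of how these two helpers behave on a batch (the shapes and values below are illustrative assumptions, not part of the library):

import torch
import shadow.vat

x = torch.randn(8, 3, 32, 32)                     # hypothetical batch of images
d = shadow.vat.rand_unit_sphere(x)                # one random direction per sample
# Each flattened per-sample direction has (approximately) unit L2 norm
print(torch.norm(d.view(d.shape[0], -1), dim=1))
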
class RPT(shadow.module_wrapper.ModuleWrapper):
    r"""Random Perturbation Training for consistency regularization.

    Random Perturbation Training (RPT) is a special case of Virtual Adversarial Training
    (VAT, [Miyato18]_) for which the number of power iterations is 0. This means that
    added perturbations are isotropically random (not in the adversarial direction).

    Args:
        eps (float): The magnitude of applied perturbation. Greater `eps` implies more
            smoothing.
        model (torch.nn.Module): The model to train and regularize.
        consistency_type ({'kl', 'mse', 'mse_regress'}, optional): Cost function used to
            measure consistency. Defaults to `'mse'` (mean squared error).
    """
    def __init__(self, eps, model, consistency_type="mse"):
        super(RPT, self).__init__(model)
        self.eps = eps

        if consistency_type == 'mse':
            self.consistency_criterion = shadow.losses.softmax_mse_loss
        elif consistency_type == 'kl':
            self.consistency_criterion = shadow.losses.softmax_kl_loss
        elif consistency_type == 'mse_regress':
            self.consistency_criterion = shadow.losses.mse_regress_loss
        else:
            raise ValueError(
                "Unknown consistency type. Should be 'mse', 'kl', or 'mse_regress', but is "
                + str(consistency_type)
            )
    def get_technique_cost(self, x):
        r"""Consistency cost (local distributional smoothness).

        Args:
            x (torch.Tensor): Tensor of the data.

        Returns:
            torch.Tensor: Consistency cost between the data and randomly perturbed data.
        """
        with torch.no_grad():
            model_logits = self.model(x)

        r = self.eps * rand_unit_sphere(x)
        model_logits_r = self.model(x + r)

        # The minibatch size is required in order to find the mean loss
        minibatch_size = x.shape[0]
        loss = self.consistency_criterion(model_logits, model_logits_r) / minibatch_size
        return loss

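A minimal usage sketch for RPT, assuming a toy torch.nn.Linear classifier and an illustrative `eps`; the consistency cost needs only unlabeled data:

import torch
import shadow.vat

model = torch.nn.Linear(10, 3)                    # hypothetical classifier
rpt = shadow.vat.RPT(eps=0.5, model=model)        # eps value is illustrative
x = torch.randn(16, 10)                           # unlabeled batch
cost = rpt.get_technique_cost(x)                  # mean consistency cost over the minibatch
cost.backward()                                   # gradients flow through the perturbed forward pass
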
def adv_perturbation(x, y, model, criterion, optimizer):
    """Find adversarial perturbation following [Goodfellow14]_.

    Args:
        x (torch.Tensor): Input data.
        y (torch.Tensor): Input labels.
        model (torch.nn.Module): The model.
        criterion (callable): The loss criterion used to measure performance.
        optimizer (torch.optim.Optimizer): Optimizer used to compute gradients.

    Returns:
        torch.Tensor: Adversarial perturbations.
    """
    x.requires_grad = True
    out = model(x)
    loss = criterion(out, y)
    optimizer.zero_grad()
    loss.backward()
    r_adv = l2_normalize(torch.sign(x.grad))
    x.requires_grad = False
    return r_adv

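A minimal sketch of computing adversarial directions with this helper, assuming a toy model, labeled batch, and optimizer; the 0.1 scale applied to the returned unit perturbation is an arbitrary illustrative choice:

import torch
import shadow.vat

model = torch.nn.Linear(10, 3)                    # hypothetical classifier
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(16, 10)
y = torch.randint(0, 3, (16,))

# Returns L2-normalized sign-of-gradient directions, one per sample
r = shadow.vat.adv_perturbation(x, y, model, torch.nn.CrossEntropyLoss(), optimizer)
x_adv = x + 0.1 * r                               # scale by a chosen perturbation magnitude
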
def vadv_perturbation(x, model, xi, eps, power_iter, consistency_criterion,
                      flip_correction=True, xi_check=False):
    r"""Find virtual adversarial perturbation following [Miyato18]_.

    Args:
        x (torch.Tensor): Input data.
        model (torch.nn.Module): The model.
        xi (float): Scaling value for the random direction vector.
        eps (float): The magnitude of applied adversarial perturbation. Greater `eps`
            implies more smoothing.
        power_iter (int): Number of power iterations used in estimation.
        consistency_criterion (callable): Cost function used to measure consistency.
        flip_correction (bool, optional): Correct flipped virtual adversarial perturbations
            induced by power iteration estimation. These iterations sometimes converge to a
            "flipped" perturbation (away from maximum change in consistency). This correction
            detects this behavior and corrects flipped perturbations at the cost of slightly
            increased compute. This behavior is not included in the original VAT
            implementation, which exhibits perturbation flipping without any corrections.
            Defaults to `True`.
        xi_check (bool, optional): Raise warnings for small perturbation lengths. The
            parameter `xi` should be selected so as to be small (for algorithm assumptions
            to be correct), but not so small as to collapse the perturbation into a length 0
            vector. This parameter controls optional warnings to detect a value of `xi` that
            causes perturbations to collapse to length 0. Defaults to `False`.

    Returns:
        torch.Tensor: Virtual adversarial perturbations.
    """
    # The minibatch size is required in order to find the mean loss
    minibatch_size = x.shape[0]

    # find regular probabilities
    with torch.no_grad():  # turn gradient off of the regular model
        out = model(x)

    # create random unit tensor
    d = rand_unit_sphere(x)

    # calculate adversarial direction
    for i in range(power_iter):
        d.requires_grad = True
        # find probability with the random tensor
        out_plus = model(x + xi * d)
        # compute the distance
        dist = consistency_criterion(out, out_plus) / minibatch_size
        dist.backward()  # gradient w.r.t. r if distance isn't 0
        d = l2_normalize(d.grad)
        model.zero_grad()

    r_adv = eps * l2_normalize(d.view(x.shape[0], -1)).view(x.shape)

    if xi_check:
        # Per-sample perturbation lengths; warn if any collapse toward zero
        bools = torch.norm(r_adv.view(r_adv.shape[0], -1), dim=1) < eps / 10
        if bools.any():
            warnings.warn("generated perturbation vector has length smaller than eps/10,"
                          + " please check settings for xi", RuntimeWarning)

    if flip_correction:
        with torch.no_grad():
            # Current per-sample distance
            currentD = consistency_criterion(
                out, model(x + r_adv), reduction="none").sum(dim=-1)
            # Flipped per-sample distance
            flippedD = consistency_criterion(
                out, model(x - r_adv), reduction="none").sum(dim=-1)

            # Flip if flipped distance is better than current
            flip = (currentD > flippedD).float() * 2 - 1

        # Flip if needed, padding out with sufficient singleton dims
        r_adv = r_adv * flip.view(-1, *([1] * (x.dim() - 1)))

    return r_adv

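A minimal sketch calling vadv_perturbation directly, assuming a toy model; the `xi` and `eps` values are illustrative, and no labels are required:

import torch
import shadow.vat
import shadow.losses

model = torch.nn.Linear(10, 3)                    # hypothetical classifier
x = torch.randn(16, 10)
r_vadv = shadow.vat.vadv_perturbation(
    x, model, xi=1e-6, eps=1.0, power_iter=1,
    consistency_criterion=shadow.losses.softmax_kl_loss)
# Each per-sample perturbation has length (approximately) eps
print(torch.norm(r_vadv.view(r_vadv.shape[0], -1), dim=1))
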
class VAT(shadow.module_wrapper.ModuleWrapper):
    r"""Virtual Adversarial Training (VAT, [Miyato18]_) model wrapper for consistency
    regularization.

    Args:
        model (torch.nn.Module): The model to train and regularize.
        xi (float, optional): Scaling value for the random direction vector. Defaults to 1.0.
        eps (float, optional): The magnitude of applied adversarial perturbation. Greater
            `eps` implies more smoothing. Defaults to 1.0.
        power_iter (int, optional): Number of power iterations used to estimate virtual
            adversarial direction. Per [Miyato18]_, defaults to 1.
        consistency_type ({'kl', 'mse', 'mse_regress'}, optional): Cost function used to
            measure consistency. Defaults to `'kl'` (KL-divergence).
        flip_correction (bool, optional): Correct flipped virtual adversarial perturbations
            induced by power iteration estimation. These iterations sometimes converge to a
            "flipped" perturbation (away from maximum change in consistency). This correction
            detects this behavior and corrects flipped perturbations at the cost of slightly
            increased compute. This behavior is not included in the original VAT
            implementation, which exhibits perturbation flipping without any corrections.
            Defaults to `True`.
        xi_check (bool, optional): Raise warnings for small perturbation lengths. The
            parameter `xi` should be selected so as to be small (for algorithm assumptions
            to be correct), but not so small as to collapse the perturbation into a length 0
            vector. This parameter controls optional warnings to detect a value of `xi` that
            causes perturbations to collapse to length 0. Defaults to `False`.
    """
    def __init__(self, model, xi=1.0, eps=1.0, power_iter=1, consistency_type="kl",
                 flip_correction=True, xi_check=False):
        super(VAT, self).__init__(model)
        self.xi = xi
        self.xi_check = xi_check
        self.eps = eps
        self.power_iter = power_iter
        self.flip_correction = flip_correction

        if self.power_iter <= 0:
            self.power_iter = 1

        if consistency_type == 'mse':
            self.consistency_criterion = shadow.losses.softmax_mse_loss
        elif consistency_type == 'kl':
            self.consistency_criterion = shadow.losses.softmax_kl_loss
        elif consistency_type == 'mse_regress':
            self.consistency_criterion = shadow.losses.mse_regress_loss
        else:
            raise ValueError(
                "Unknown consistency type. Should be 'mse', 'kl', or 'mse_regress', but is "
                + str(consistency_type)
            )
    def get_technique_cost(self, x):
        r"""VAT consistency cost (local distributional smoothness).

        Args:
            x (torch.Tensor): Tensor of the data.

        Returns:
            torch.Tensor: Consistency cost between the data and virtual adversarially
            perturbed data.
        """
        with torch.no_grad():
            out = self.model(x)

        r = vadv_perturbation(x, self.model, self.xi, self.eps, self.power_iter,
                              self.consistency_criterion,
                              flip_correction=self.flip_correction, xi_check=self.xi_check)
        out_plus = self.model(x + r)

        # The minibatch size is required in order to find the mean loss
        minibatch_size = x.shape[0]
        loss = self.consistency_criterion(out, out_plus) / minibatch_size
        return loss

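A minimal semi-supervised training sketch with the VAT wrapper, assuming the wrapper forwards inputs through the wrapped model, a toy classifier on synthetic data, a -1 label convention for unlabeled samples, and an illustrative consistency weight `alpha`; none of these choices are prescribed by the library:

import torch
import shadow.vat

model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.ReLU(), torch.nn.Linear(32, 3))
vat = shadow.vat.VAT(model, xi=1e-6, eps=2.0, power_iter=1, consistency_type="kl")
optimizer = torch.optim.SGD(vat.parameters(), lr=0.01)
ce = torch.nn.CrossEntropyLoss()
alpha = 1.0                                       # weight on the consistency cost (assumed)

x = torch.randn(16, 10)                           # mixed labeled/unlabeled batch
y = torch.randint(0, 3, (16,))
y[8:] = -1                                        # -1 marks unlabeled samples (assumed convention)

optimizer.zero_grad()
labeled = y != -1
# Supervised loss on labeled samples plus consistency cost over the full batch
loss = ce(vat(x)[labeled], y[labeled]) + alpha * vat.get_technique_cost(x)
loss.backward()
optimizer.step()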