shadow.vat module
-
class shadow.vat.RPT(eps, model, consistency_type='mse')
Bases: shadow.module_wrapper.ModuleWrapper
Random Perturbation Training for consistency regularization.
Random Perturbation Training (RPT) is a special case of Virtual Adversarial Training (VAT, [Miyato18]) for which the number of power iterations is 0. This means that added perturbations are isotropically random (not in the adversarial direction).
- Parameters
eps (float) – The magnitude of applied perturbation. Greater eps implies more smoothing.
model (torch.nn.Module) – The model to train and regularize.
consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to ‘mse’ (mean squared error), per the class signature.
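The idea behind RPT can be illustrated with a minimal, self-contained PyTorch sketch. It does not reproduce the library's implementation: the helper name `rpt_consistency_loss`, the use of raw model outputs in the MSE, and the choice to treat the clean prediction as a fixed target are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def rpt_consistency_loss(model, x, eps):
    """Hypothetical helper: MSE between model(x) and model(x + r), with ||r||_2 = eps."""
    # Isotropic random direction, normalized per sample over all non-batch dims.
    r = torch.randn_like(x)
    norms = r.flatten(start_dim=1).norm(dim=1).clamp_min(1e-12)
    r = eps * r / norms.view(-1, *([1] * (x.dim() - 1)))
    # Assumption: the clean prediction is a fixed target (no gradient through it).
    with torch.no_grad():
        clean = model(x)
    return F.mse_loss(model(x + r), clean)


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)   # toy stand-in for the wrapped model
x = torch.randn(8, 4)           # unlabeled batch
loss = rpt_consistency_loss(model, x, eps=0.5)
```

Because no power iteration is run, the perturbation direction is never refined toward the adversarial direction. Only its length, `eps`, is controlled.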
-
class shadow.vat.VAT(model, xi=1.0, eps=1.0, power_iter=1, consistency_type='kl', flip_correction=True, xi_check=False)
Bases: shadow.module_wrapper.ModuleWrapper
Virtual Adversarial Training (VAT, [Miyato18]) model wrapper for consistency regularization.
- Parameters
model (torch.nn.Module) – The model to train and regularize.
xi (float, optional) – Scaling value for the random direction vector. Defaults to 1.0.
eps (float, optional) – The magnitude of applied adversarial perturbation. Greater eps implies more smoothing. Defaults to 1.0.
power_iter (int, optional) – Number of power iterations used to estimate virtual adversarial direction. Per [Miyato18], defaults to 1.
consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to ‘kl’ (KL-divergence).
flip_correction (bool, optional) – Correct flipped virtual adversarial perturbations induced by power iteration estimation. These iterations sometimes converge to a “flipped” perturbation (away from maximum change in consistency). This correction detects this behavior and corrects flipped perturbations at the cost of slightly increased compute. This behavior is not included in the original VAT implementation, which exhibits perturbation flipping without any corrections. Defaults to True.
xi_check (bool, optional) – Raise warnings for small perturbation lengths. The parameter xi should be small enough for the algorithm's assumptions to hold, but not so small that the perturbation collapses into a length 0 vector. This parameter controls optional warnings that detect a value of xi that causes perturbations to collapse to length 0. Defaults to False.
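One way a too-small xi can collapse a perturbation is floating-point rounding: in float32, adding a tiny step to an input of ordinary magnitude may leave the input unchanged. The snippet below only illustrates that effect. Whether the library's check tests exactly this condition is not stated here, and the variable names are hypothetical.

```python
import torch

x = torch.ones(4, dtype=torch.float32)
d = torch.full((4,), 0.5)  # unit-length direction: sqrt(4 * 0.5**2) = 1

survived = {}
for xi in (1e-1, 1e-10):
    # Length of the perturbation that is actually left after rounding.
    survived[xi] = ((x + xi * d) - x).norm().item()

print(survived)  # the entry for xi=1e-10 is exactly 0.0 in float32
```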
-
shadow.vat.adv_perturbation(x, y, model, criterion, optimizer)
Find adversarial perturbation following [Goodfellow14].
- Parameters
x (torch.Tensor) – Input data.
y (torch.Tensor) – Input labels.
model (torch.nn.Module) – The model.
criterion (callable) – The loss criterion used to measure performance.
optimizer (torch.optim.Optimizer) – Optimizer used to compute gradients.
- Returns
Adversarial perturbations.
- Return type
torch.Tensor
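For orientation, the fast gradient sign method from [Goodfellow14] can be sketched in a few lines of PyTorch. This is not the library's code: the helper name `fgsm_perturbation` and its `eps` argument are hypothetical, and the sketch omits the `optimizer` argument that the documented function takes.

```python
import torch
import torch.nn.functional as F


def fgsm_perturbation(x, y, model, criterion, eps=0.1):
    """Hypothetical helper: eps * sign(d loss / d x)."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = criterion(model(x_adv), y)
    # Gradient with respect to the input only; parameter .grad fields are untouched.
    (grad,) = torch.autograd.grad(loss, x_adv)
    return eps * grad.sign()


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)        # toy classifier
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
r_adv = fgsm_perturbation(x, y, model, F.cross_entropy)
```

Adding `r_adv` to `x` moves each input in the direction that locally increases the supervised loss, which is why labels `y` are required here but not for the virtual adversarial variant.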
-
shadow.vat.l2_normalize(r)
L2 normalize tensor, flattening over all but batch dim (0).
- Parameters
r (torch.Tensor) – Tensor to normalize.
- Returns
Normalized tensor (over all but dim 0).
- Return type
torch.Tensor
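The normalization described above can be sketched as follows. The function name `l2_normalize_sketch` and the small `tiny` floor that guards against division by zero are assumptions for illustration, not the library's implementation.

```python
import torch


def l2_normalize_sketch(r, tiny=1e-12):
    """Hypothetical helper: scale each batch element to unit L2 norm."""
    flat = r.flatten(start_dim=1)                       # (batch, everything else)
    norms = flat.norm(p=2, dim=1, keepdim=True).clamp_min(tiny)
    return (flat / norms).view_as(r)                    # restore the original shape


r = torch.arange(1.0, 25.0).reshape(2, 3, 4)
unit = l2_normalize_sketch(r)
```

Each of the two batch elements in `unit` has L2 norm 1 when flattened, while the tensor keeps its `(2, 3, 4)` shape.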
-
shadow.vat.rand_unit_sphere(x)
Draw samples from the uniform distribution over the unit sphere.
- Parameters
x (torch.Tensor) – Tensor used to define shape and dtype for the generated tensor.
- Returns
Random unit sphere samples.
- Return type
torch.Tensor
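A standard way to sample uniformly from the unit sphere is to draw an isotropic Gaussian vector and divide it by its norm. The sketch below takes that route. Whether the library does the same is an assumption, and the name `rand_unit_sphere_sketch` is hypothetical.

```python
import torch


def rand_unit_sphere_sketch(x):
    """Hypothetical helper: one unit-norm random direction per batch element of x."""
    g = torch.randn_like(x)  # same shape, dtype and device as x
    norms = g.flatten(start_dim=1).norm(dim=1).clamp_min(1e-12)
    return g / norms.view(-1, *([1] * (x.dim() - 1)))


torch.manual_seed(0)
x = torch.zeros(5, 3, 2, dtype=torch.float64)
d = rand_unit_sphere_sketch(x)
```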
-
shadow.vat.vadv_perturbation(x, model, xi, eps, power_iter, consistency_criterion, flip_correction=True, xi_check=False)
Find virtual adversarial perturbation following [Miyato18].
- Parameters
x (torch.Tensor) – Input data.
model (torch.nn.Module) – The model.
xi (float) – Scaling value for the random direction vector.
eps (float) – The magnitude of applied adversarial perturbation. Greater eps implies more smoothing.
power_iter (int) – Number of power iterations used in estimation.
consistency_criterion (callable) – Cost function used to measure consistency.
flip_correction (bool, optional) – Correct flipped virtual adversarial perturbations induced by power iteration estimation. These iterations sometimes converge to a “flipped” perturbation (away from maximum change in consistency). This correction detects this behavior and corrects flipped perturbations at the cost of slightly increased compute. This behavior is not included in the original VAT implementation, which exhibits perturbation flipping without any corrections. Defaults to True.
xi_check (bool, optional) – Raise warnings for small perturbation lengths. The parameter xi should be small enough for the algorithm's assumptions to hold, but not so small that the perturbation collapses into a length 0 vector. This parameter controls optional warnings that detect a value of xi that causes perturbations to collapse to length 0. Defaults to False.
- Returns
Virtual adversarial perturbations.
- Return type
torch.Tensor
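The power-iteration estimate from [Miyato18] can be sketched in plain PyTorch as below. This is an illustrative reading of the algorithm, not the library's code: the helper names, the fixed KL criterion, and the omission of `flip_correction` and `xi_check` are all assumptions.

```python
import torch
import torch.nn.functional as F


def _unit(r):
    # Normalize each batch element to unit L2 norm over all non-batch dims.
    norms = r.flatten(start_dim=1).norm(dim=1).clamp_min(1e-12)
    return r / norms.view(-1, *([1] * (r.dim() - 1)))


def kl_consistency(logits_pert, logits_clean):
    # KL(p_clean || p_pert), averaged over the batch.
    return F.kl_div(F.log_softmax(logits_pert, dim=1),
                    F.softmax(logits_clean, dim=1),
                    reduction="batchmean")


def vadv_sketch(x, model, xi, eps, power_iter):
    """Hypothetical helper: estimate the virtual adversarial perturbation."""
    with torch.no_grad():
        clean = model(x)                       # fixed reference prediction
    d = _unit(torch.randn_like(x))             # random starting direction
    for _ in range(power_iter):
        d = d.detach().requires_grad_(True)
        dist = kl_consistency(model(x + xi * d), clean)
        (grad,) = torch.autograd.grad(dist, d)
        d = _unit(grad)                        # one power-iteration step
    return eps * d.detach()


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
x = torch.randn(8, 4)
r_vadv = vadv_sketch(x, model, xi=1.0, eps=2.0, power_iter=1)
r_rand = vadv_sketch(x, model, xi=1.0, eps=2.0, power_iter=0)
```

With `power_iter=0` the loop never runs and the result is a random direction scaled to length `eps`, which corresponds to the RPT special case. No labels are needed at any point, so the perturbation can be computed on unlabeled data.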