shadow.vat module

class shadow.vat.RPT(eps, model, consistency_type='mse')[source]

Bases: shadow.module_wrapper.ModuleWrapper

Random Perturbation Training for consistency regularization.

Random Perturbation Training (RPT) is a special case of Virtual Adversarial Training (VAT, [Miyato18]) for which the number of power iterations is 0. This means that added perturbations are isotropically random (not in the adversarial direction).

Parameters
  • eps (float) – The magnitude of applied perturbation. Greater eps implies more smoothing.

  • model (torch.nn.Module) – The model to train and regularize.

  • consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to 'mse' (mean squared error), matching the signature above.

get_technique_cost(x)[source]

Consistency cost (local distributional smoothness).

Parameters

x (torch.Tensor) – Input data.

Returns

Consistency cost between the data and randomly perturbed data.

Return type

torch.Tensor
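
Example – a minimal usage sketch (the toy classifier, shapes, and eps value are illustrative, not part of the library):

    import torch
    import shadow.vat

    model = torch.nn.Linear(10, 2)  # toy classifier (illustrative)
    rpt = shadow.vat.RPT(eps=0.5, model=model, consistency_type='mse')

    x_unlabeled = torch.randn(8, 10)
    # Consistency cost between predictions on the data and on randomly
    # perturbed copies of it; usable as an unsupervised loss term.
    cost = rpt.get_technique_cost(x_unlabeled)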

class shadow.vat.VAT(model, xi=1.0, eps=1.0, power_iter=1, consistency_type='kl', flip_correction=True, xi_check=False)[source]

Bases: shadow.module_wrapper.ModuleWrapper

Virtual Adversarial Training (VAT, [Miyato18]) model wrapper for consistency regularization.

Parameters
  • model (torch.nn.Module) – The model to train and regularize.

  • xi (float, optional) – Scaling value for the random direction vector. Defaults to 1.0.

  • eps (float, optional) – The magnitude of applied adversarial perturbation. Greater eps implies more smoothing. Defaults to 1.0.

  • power_iter (int, optional) – Number of power iterations used to estimate virtual adversarial direction. Per [Miyato18], defaults to 1.

  • consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to ‘kl’ (KL-divergence).

  • flip_correction (bool, optional) – Correct flipped virtual adversarial perturbations induced by the power iteration estimate. The iterations sometimes converge to a "flipped" perturbation (pointing away from the direction of maximum change in consistency). This option detects and corrects flipped perturbations at the cost of slightly increased compute. The original VAT implementation does not include this correction and thus exhibits perturbation flipping. Defaults to True.

  • xi_check (bool, optional) – Raise warnings for small perturbation lengths. The parameter xi should be small (for the algorithm's assumptions to hold), but not so small that the perturbation collapses to a length-0 vector. This flag enables optional warnings that detect a value of xi causing perturbations to collapse to length 0. Defaults to False.

get_technique_cost(x)[source]

VAT consistency cost (local distributional smoothness).

Parameters

x (torch.Tensor) – Input data.

Returns

Consistency cost between the data and virtual adversarially perturbed data.

Return type

torch.Tensor
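
Example – a sketch of one combined semi-supervised step (this assumes the ModuleWrapper base forwards calls and parameters to the wrapped model; the toy model, data, and equal loss weighting are illustrative):

    import torch
    import shadow.vat

    model = torch.nn.Linear(10, 2)  # toy classifier (illustrative)
    vat = shadow.vat.VAT(model=model, xi=1.0, eps=1.0, power_iter=1)
    optimizer = torch.optim.SGD(vat.parameters(), lr=0.1)
    xent = torch.nn.CrossEntropyLoss()

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))  # labeled batch
    x_unlabeled = torch.randn(32, 10)                      # unlabeled batch

    optimizer.zero_grad()
    # Supervised loss on labeled data plus VAT consistency cost on
    # unlabeled data.
    loss = xent(vat(x), y) + vat.get_technique_cost(x_unlabeled)
    loss.backward()
    optimizer.step()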

shadow.vat.adv_perturbation(x, y, model, criterion, optimizer)[source]

Find adversarial perturbation following [Goodfellow14].

Parameters
  • x (torch.Tensor) – Input data.

  • y (torch.Tensor) – Input labels.

  • model (torch.nn.Module) – The model.

  • criterion (callable) – The loss criterion used to measure performance.

  • optimizer (torch.optim.Optimizer) – Optimizer used to compute gradients.

Returns

Adversarial perturbations.

Return type

torch.Tensor
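
Example – a usage sketch for adversarial training in the style of [Goodfellow14] (the toy model, criterion, and shapes are illustrative):

    import torch
    import shadow.vat

    model = torch.nn.Linear(10, 2)  # toy classifier (illustrative)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    x = torch.randn(8, 10)
    y = torch.randint(0, 2, (8,))

    # Perturbations that (approximately) maximize the criterion at x.
    r_adv = shadow.vat.adv_perturbation(x, y, model, criterion, optimizer)

    optimizer.zero_grad()
    loss = criterion(model(x + r_adv), y)  # train on perturbed inputs
    loss.backward()
    optimizer.step()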

shadow.vat.l2_normalize(r)[source]

L2-normalize a tensor, flattening over all but the batch dimension (dim 0).

Parameters

r (torch.Tensor) – Tensor to normalize.

Returns

Normalized tensor (over all but dim 0).

Return type

torch.Tensor
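
A sketch of the intended semantics (the library implementation may differ in details, e.g. guarding against zero-length inputs):

    import torch

    def l2_normalize_sketch(r):
        # Per-sample L2 norm over all non-batch dimensions, broadcast
        # back so each batch element is scaled to unit length.
        norms = r.flatten(start_dim=1).norm(dim=1)
        return r / norms.view(-1, *([1] * (r.dim() - 1)))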

shadow.vat.rand_unit_sphere(x)[source]

Draw samples from the uniform distribution over the unit sphere.

Parameters

x (torch.Tensor) – Tensor used to define shape and dtype for the generated tensor.

Returns

Random unit sphere samples.

Return type

torch.Tensor

Reference:

https://stats.stackexchange.com/questions/7977/how-to-generate-uniformly-distributed-points-on-the-surface-of-the-3-d-unit-sphe
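
The linked reference describes the standard construction: isotropic Gaussian draws, L2-normalized per sample, are uniform on the unit sphere. A sketch under that assumption:

    import torch

    def rand_unit_sphere_sketch(x):
        # Gaussian samples are rotationally symmetric, so normalizing
        # each batch element to unit length yields points uniformly
        # distributed on the unit sphere.
        g = torch.randn_like(x)
        flat = g.flatten(start_dim=1)
        return (flat / flat.norm(dim=1, keepdim=True)).view_as(x)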

shadow.vat.vadv_perturbation(x, model, xi, eps, power_iter, consistency_criterion, flip_correction=True, xi_check=False)[source]

Find virtual adversarial perturbation following [Miyato18].

Parameters
  • x (torch.Tensor) – Input data.

  • model (torch.nn.Module) – The model.

  • xi (float) – Scaling value for the random direction vector.

  • eps (float) – The magnitude of applied adversarial perturbation. Greater eps implies more smoothing.

  • power_iter (int) – Number of power iterations used in estimation.

  • consistency_criterion (callable) – Cost function used to measure consistency.

  • flip_correction (bool, optional) – Correct flipped virtual adversarial perturbations induced by the power iteration estimate. The iterations sometimes converge to a "flipped" perturbation (pointing away from the direction of maximum change in consistency). This option detects and corrects flipped perturbations at the cost of slightly increased compute. The original VAT implementation does not include this correction and thus exhibits perturbation flipping. Defaults to True.

  • xi_check (bool, optional) – Raise warnings for small perturbation lengths. The parameter xi should be small (for the algorithm's assumptions to hold), but not so small that the perturbation collapses to a length-0 vector. This parameter controls optional warnings that detect a value of xi causing perturbations to collapse to length 0. Defaults to False.

Returns

Virtual adversarial perturbations.

Return type

torch.Tensor
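
A minimal sketch of the power-iteration estimate (omitting the flip_correction and xi_check behavior; consistency_criterion is assumed to accept raw model outputs for both arguments and return a scalar, e.g. torch.nn.MSELoss()):

    import torch

    def vadv_perturbation_sketch(x, model, xi, eps, power_iter,
                                 consistency_criterion):
        with torch.no_grad():
            y = model(x)  # reference predictions on clean data

        def normalize(v):
            # Unit-normalize per batch element over non-batch dims.
            return v / v.flatten(start_dim=1).norm(dim=1).view(
                -1, *([1] * (v.dim() - 1)))

        d = normalize(torch.randn_like(x))  # random initial direction
        for _ in range(power_iter):
            d = (xi * d).requires_grad_()
            dist = consistency_criterion(model(x + d), y)
            grad, = torch.autograd.grad(dist, d)
            d = normalize(grad)  # dominant direction of change
        return eps * d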