Source code for shadow.pseudo

import torch
import shadow.module_wrapper


class Threshold(torch.nn.Module):
    r"""Per-class thresholding operator.

    Args:
        thresholds (torch.Tensor): 1D `float` array of thresholds with length
            equal to the number of classes. Each element should be in
            :math:`[0, 1]` and represents a per-class threshold. Thresholds
            are with respect to normalized scores (i.e. scores that sum to 1).

    Example:
        >>> myThresholder = Threshold([.8, .9])
        >>> myThresholder(torch.tensor([[10., 90.], [95., 95.4], [0.3, 0.4]]))
        tensor([1., 0., 0.])
    """
    def __init__(self, thresholds):
        super(Threshold, self).__init__()
        self.thresholds = torch.nn.Parameter(torch.Tensor(thresholds))
    def forward(self, predictions):
        r"""Threshold multi-class scores.

        Args:
            predictions (torch.Tensor): 2D model outputs of shape
                `(n_samples, n_classes)`. Does not need to be normalized in
                advance.

        Returns:
            torch.Tensor: Binary thresholding result for each sample.
        """
        predictions = torch.nn.functional.softmax(predictions, dim=1)
        return torch.any(predictions > self.thresholds, dim=1).float()
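
# A step-by-step sketch of the comparison performed in ``Threshold.forward``
# (``_threshold_demo`` is a hypothetical helper and the logits are made-up
# values): ``predictions > self.thresholds`` broadcasts the length-n_classes
# threshold vector across each row of the (n_samples, n_classes) softmax
# matrix, and ``torch.any(..., dim=1)`` keeps a sample when any class clears
# its own threshold.
def _threshold_demo():
    logits = torch.tensor([[10.0, 90.0], [95.0, 95.4], [0.3, 0.4]])
    # Normalize each row so scores sum to 1, as Threshold.forward does
    probs = torch.nn.functional.softmax(logits, dim=1)
    # Row-wise broadcast against the per-class thresholds
    keep = torch.any(probs > torch.tensor([0.8, 0.9]), dim=1)
    return keep.float()  # tensor([1., 0., 0.])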
class PL(shadow.module_wrapper.ModuleWrapper):
    r"""Pseudo Label model wrapper.

    The pseudo labeling wrapper weights samples according to model score.
    This is a form of entropy regularization. For example, a binary random
    variable with distribution :math:`P(X=1) = .5` and :math:`P(X=0) = .5`
    has a much higher entropy (:math:`\ln 2 \approx 0.69` nats) than one with
    :math:`P(X=1) = .9` and :math:`P(X=0) = .1` (:math:`\approx 0.33` nats).

    Args:
        model (torch.nn.Module): the model to wrap.
        weight_function (callable): assigns weighting based on raw model
            outputs.
        ssml_mode (bool, optional): semi-supervised learning mode, toggles
            whether loss is computed for all inputs or just those data with
            missing labels. Defaults to True.
        missing_label (int, optional): integer value used to represent
            missing labels. Defaults to -1.
    """
    def __init__(self, model, weight_function, ssml_mode=True, missing_label=-1):
        super(PL, self).__init__(model)
        self.weight_function = weight_function
        self.ssml_mode = ssml_mode
        self.missing_label = missing_label
        self.loss = torch.nn.CrossEntropyLoss(reduction='none')
    def get_technique_cost(self, x, targets):
        r"""Compute loss from pseudo labeling.

        Args:
            x (torch.Tensor): Input data.
            targets (torch.Tensor): 1D corresponding labels. Unlabeled data is
                specified according to `self.missing_label`.

        Returns:
            torch.Tensor: Pseudo label loss.
        """
        # Get the predicted labels from the model
        predictions = self.model(x)
        predicted_labels = torch.argmax(predictions, 1)
        # Unlabeled samples in the batch
        unlabeled = targets == self.missing_label
        # For unlabeled data, 'fake' the labels with the model's predictions
        semi_fake_targets = torch.where(unlabeled, predicted_labels, targets)
        # Per-sample weighting
        _weights = self.weight_function(predictions).float()
        # In ssml mode, don't weight the labeled samples
        if self.ssml_mode:
            _weights[~unlabeled] = 0
        # Cross entropy loss for every prediction regardless of label status,
        # weighted per sample
        per_sample_loss = self.loss(predictions, semi_fake_targets) * _weights
        # Only count samples with non-zero weight
        indexes_to_keep = _weights != 0
        # If there are no samples to keep, return 0 for the loss
        if indexes_to_keep.sum() == 0:
            return torch.Tensor([0]).to(per_sample_loss.device)
        return per_sample_loss[indexes_to_keep].mean()
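
# A minimal end-to-end training sketch, assuming the usual pattern of pairing
# a supervised loss with ``get_technique_cost``; the linear model, random
# data, and optimizer settings below are made up for illustration only.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(5, 2)
    pl = PL(model, weight_function=Threshold([0.8, 0.9]))
    optimizer = torch.optim.SGD(pl.parameters(), lr=0.1)

    x = torch.randn(8, 5)
    y = torch.tensor([0, 1, 0, 1, -1, -1, -1, -1])  # -1 marks unlabeled samples

    # Supervised loss over the labeled samples only, plus the pseudo label
    # penalty, which is computed over the full batch
    labeled = y != -1
    supervised = torch.nn.functional.cross_entropy(pl.model(x[labeled]), y[labeled])
    loss = supervised + pl.get_technique_cost(x, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()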