MNIST Example

For this demo, we load the MNIST (handwritten digits) dataset using torchvision, define a simple convolutional architecture, and train a prediction model with the exponential average adversarial training (EAAT) technique using only 10% of the MNIST training labels. This example is meant as a quick-start guide and complements the rest of the documentation.

[1]:
# torch imports
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# shadow-ssml imports
import shadow.eaat
from shadow.utils import set_seed

# helpers
import numpy as np
import random

Torchvision makes it easy to load a variety of standard datasets and apply common preprocessing transforms. The stock MNIST class returns a fully labeled training set, so we subclass it to return partially labeled (labeled and unlabeled) training data. In our training dataset, 90% of the MNIST training labels are reassigned to -1, and the indices to relabel are drawn with a fixed sampling seed so the split is reproducible. For evaluating SSL classification performance, we use the standard torchvision MNIST test partition with all labels kept.
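The relabeling step can be illustrated without torch. The sketch below is a hypothetical helper, `mask_labels`, that uses Python's `random.Random` with its own seed as a stand-in for the NumPy global generator seeded by `set_seed` in the notebook. It shows that exactly `int(n * unlabeled_frac)` distinct samples lose their labels and that the same seed always yields the same split.

```python
import random

def mask_labels(labels, unlabeled_frac=0.9, seed=0, unlabeled_value=-1):
    """Return a copy of labels with a seeded random fraction set to unlabeled_value."""
    rng = random.Random(seed)  # dedicated generator: the split depends only on seed
    n_drop = int(len(labels) * unlabeled_frac)
    # draw distinct indices, like np.random.choice(..., replace=False)
    drop = set(rng.sample(range(len(labels)), n_drop))
    return [unlabeled_value if i in drop else y for i, y in enumerate(labels)]

labels = [i % 10 for i in range(60000)]   # stand-in for the 60000 MNIST targets
masked = mask_labels(labels)
print(sum(y == -1 for y in masked))        # 54000
print(masked == mask_labels(labels))       # True: same seed, same split
```

Samples that keep their labels are untouched, so the remaining 6000 labeled examples are a uniform random subset of the original training set.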

[2]:
datadir = 'data'
# seed the random number generators so the labeled/unlabeled split is reproducible
set_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class UnlabeledMNIST(torchvision.datasets.MNIST):
    """MNIST dataset in which a fraction of the labels is replaced by -1."""
    def __init__(self, root, train=True,
                 transform=torchvision.transforms.ToTensor(),
                 download=False, unlabeled_frac=0.9):
        super(UnlabeledMNIST, self).__init__(root,
                 train=train, transform=transform,
                 download=download)
        # choose, without replacement, which samples lose their labels
        labels_to_drop = np.random.choice(len(self),
                 size=int(len(self) * unlabeled_frac),
                 replace=False)
        # -1 marks a sample as unlabeled
        self.targets[labels_to_drop] = -1


# partially labeled training data (90% of labels set to -1)
dataset = UnlabeledMNIST(datadir, train=True, download=True,
                         transform=torchvision.transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset, batch_size=100)

# fully labeled test partition, used only for evaluation
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
    datadir, train=False, download=True,
    transform=torchvision.transforms.ToTensor()),
    batch_size=100, shuffle=True)
[3]:
print(dataset)
Dataset UnlabeledMNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

Next we define a parameter dictionary containing only the non-default parameters for EAAT. We rarely require more than one power iteration to compute the adversarial direction, so we do not override that setting, and we likewise keep the defaults for student and teacher noise. As a reminder, EAAT combines exponential averaging, which uses random Gaussian perturbations, with adversarial training, which uses data-specific adversarial perturbations. If your dataset may benefit from both additive noise and adversarial perturbations, pass the EAAT parameters {student_noise, teacher_noise} to the model explicitly and include them in your hyperparameter searches.
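To make the exponential-averaging half concrete, here is a stdlib-only sketch of the mean-teacher idea it is assumed to build on: a teacher whose weights are an exponential moving average (EMA) of the student's weights, with student and teacher scoring independently noised copies of the same input. The toy linear "model", the function names, and the `alpha` decay parameter are all hypothetical illustrations, not shadow's API, and shadow's EAAT may differ in its details.

```python
import random

def ema_update(teacher, student, alpha=0.999):
    """Move each teacher weight toward the matching student weight (in place)."""
    teacher[:] = [alpha * t + (1.0 - alpha) * s for t, s in zip(teacher, student)]
    return teacher

def add_gaussian_noise(x, std, rng):
    """Additive Gaussian input noise, drawn independently for every element."""
    return [v + rng.gauss(0.0, std) for v in x]

def predict(weights, x):
    """Hypothetical toy 'model': a single linear unit."""
    return sum(w * v for w, v in zip(weights, x))

rng = random.Random(0)
student = [1.0, -1.0]
teacher = [0.0, 0.0]

# the teacher tracks the (here fixed) student weights with an EMA
for _ in range(3):
    ema_update(teacher, student, alpha=0.5)
print(teacher)  # [0.875, -0.875]

# consistency cost: squared gap between student and teacher outputs
# on two independently noised copies of the same input
x = [0.2, 0.4]
s_out = predict(student, add_gaussian_noise(x, 0.1, rng))
t_out = predict(teacher, add_gaussian_noise(x, 0.1, rng))
consistency = (s_out - t_out) ** 2
```

A small `alpha` is used here only so the convergence is visible after three steps; values close to 1 make the teacher a slow, smoothed copy of the student.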

[4]:
eaatparams = {
        "xi": 1e-8,
        "eps": 2.3,
        }
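Assuming, as the `xi` and `eps` names suggest, that EAAT's adversarial component follows virtual adversarial training (VAT), `xi` is the size of the small random probe used during power iteration and `eps` is the norm of the final adversarial perturbation. The sketch below replaces the network's output divergence with a hypothetical quadratic stand-in D(r) = 0.5 rᵀHr, whose gradient at the probe r = xi·d is exactly H(xi·d), so each normalized gradient step is one power-iteration step toward the top eigenvector of H.

```python
import math
import random

def normalize(v):
    """Scale a vector to unit Euclidean norm."""
    n = math.sqrt(sum(c * c for c in v))
    return [c / n for c in v]

def matvec(H, v):
    """Dense matrix-vector product for lists of lists."""
    return [sum(h * c for h, c in zip(row, v)) for row in H]

def adversarial_direction(H, xi=1e-8, eps=2.3, n_power=1, seed=0):
    """VAT-style power iteration on the quadratic stand-in D(r) = 0.5 * r^T H r."""
    rng = random.Random(seed)
    d = normalize([rng.gauss(0.0, 1.0) for _ in range(len(H))])  # random start
    for _ in range(n_power):
        # gradient of D at the tiny probe r = xi * d; for this D it equals H (xi d)
        grad = matvec(H, [xi * c for c in d])
        d = normalize(grad)  # one power-iteration step
    return [eps * c for c in d]  # scale the unit direction to radius eps

H = [[4.0, 0.0], [0.0, 1.0]]           # most sensitive direction is the first axis
r_one = adversarial_direction(H)        # one iteration, as in the notebook
r_many = adversarial_direction(H, n_power=20)
print(math.hypot(*r_one))               # ~2.3: the norm is always eps
```

In this exact quadratic case `xi` cancels after normalization; in a real network the gradient is only approximately H(xi·d), which is why `xi` is kept very small.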

Here we define a simple convolutional architecture with ReLU activations and dropout. The forward method returns raw logits rather than applying softmax to the final layer, because the loss for each technique typically applies the softmax scaling itself. We then instantiate the model, wrap it with EAAT, and create the optimizer over the EAAT parameters.

[5]:
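class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)    # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)   # 32x26x26 -> 64x24x24
        # channel-wise dropout for the 4D convolutional feature maps
        self.dropout_conv = nn.Dropout2d(0.1)
        # element-wise dropout for the 2D fully connected activations
        self.dropout_fc = nn.Dropout(0.1)
        self.fc1 = nn.Linear(9216, 128)        # 64 * 12 * 12 = 9216 after pooling
        self.fc2 = nn.Linear(128, 10)          # one logit per digit class

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout_conv(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)
        return x  # raw logits, no softmax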


model = Net()
eaat = shadow.eaat.EAAT(model=model, **eaatparams)
optimizer = optim.SGD(eaat.parameters(), lr=0.01)
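The 9216 input features of `fc1` follow from the layer shapes: two unpadded 3×3 convolutions with stride 1, then 2×2 max pooling, on 28×28 inputs. The small helper below (a hypothetical name) applies the standard output-size formula for a convolution without dilation.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

h = 28                      # MNIST images are 1 x 28 x 28
h = conv2d_out(h, 3)        # conv1: 26
h = conv2d_out(h, 3)        # conv2: 24
h = h // 2                  # max_pool2d(x, 2): 12
flat = 64 * h * h           # conv2 produces 64 channels
print(flat)                 # 9216
```

If you change the kernel sizes, padding, or pooling, recompute this value and update `nn.Linear(9216, 128)` to match.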

At this point, we have partially labeled training data available through train_loader and fully labeled test data through test_loader. We have initialized a model, wrapped it with EAAT, and passed the EAAT parameters. The last step is to train the model. For the SSL techniques implemented here, the loss combines a loss on the labeled data, typically cross-entropy, with the technique-specific consistency cost. We define the labeled data cost (xEnt) to ignore targets of -1, the value we assigned to unlabeled samples. During training, we give the labeled loss and the consistency loss equal weight by simply adding them together.
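The effect of `ignore_index=-1` can be checked with a plain-Python cross-entropy. The hypothetical `cross_entropy` below mirrors the default mean reduction of `torch.nn.CrossEntropyLoss`: softmax is applied to raw logits inside the loss, and ignored samples count toward neither the sum nor the divisor. The consistency value is a made-up placeholder for `eaat.get_technique_cost(x)`.

```python
import math

def cross_entropy(logits, targets, ignore_index=-1):
    """Mean cross-entropy over samples whose target is not ignore_index."""
    total, count = 0.0, 0
    for row, t in zip(logits, targets):
        if t == ignore_index:
            continue  # unlabeled sample: contributes nothing
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[t]
        count += 1
    return total / count if count else float('nan')  # 0/0 if nothing is labeled

logits = [[2.0, 0.0], [0.0, 0.0], [1.0, 3.0]]
targets = [0, -1, 1]                      # the middle sample is unlabeled
labeled_loss = cross_entropy(logits, targets)
consistency_cost = 0.05                   # hypothetical stand-in for the EAAT cost
loss = labeled_loss + consistency_cost    # equal weighting, as in the training loop
print(labeled_loss)                       # log(1 + e^-2), about 0.1269
```

The `nan` case matters in practice: with the mean reduction, PyTorch's `CrossEntropyLoss` also returns `nan` when every target in a batch equals `ignore_index`, and that value would propagate into the update. With 10% labels and batches of 100 this is unlikely, but with smaller batches or fewer labels it can be worth skipping the labeled term when `(y != -1).any()` is false.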

[6]:
# labeled loss: cross-entropy that skips targets equal to -1 (unlabeled samples)
xEnt = torch.nn.CrossEntropyLoss(ignore_index=-1)

eaat.to(device)
losscurve = []
for epoch in range(10):
    eaat.train()  # training mode: enables dropout
    lossavg = []
    for i, (data, targets) in enumerate(train_loader):
        x = data.to(device)
        y = targets.to(device)
        optimizer.zero_grad()
        out = eaat(x)
        # labeled loss plus EAAT consistency cost, weighted equally
        loss = xEnt(out, y) + eaat.get_technique_cost(x)
        loss.backward()
        optimizer.step()
        lossavg.append(loss.item())
    # report the median batch loss for the epoch
    losscurve.append(np.median(lossavg))
    print('epoch {} loss: {}'.format(epoch, losscurve[-1]))
epoch 0 loss: 1.6615383625030518
epoch 1 loss: 1.2582014799118042
epoch 2 loss: 1.0733909010887146
epoch 3 loss: 0.9297202229499817
epoch 4 loss: 0.8314944803714752
epoch 5 loss: 0.7584533393383026
epoch 6 loss: 0.6920907497406006
epoch 7 loss: 0.6233154237270355
epoch 8 loss: 0.5829548835754395
epoch 9 loss: 0.5472914576530457

After training, we evaluate classification accuracy on the test set.

[7]:
eaat.eval()  # evaluation mode: disables dropout
y_pred, y_true = [], []
with torch.no_grad():  # no gradients are needed at test time
    for i, (data, targets) in enumerate(test_loader):
        x = data.to(device)
        y = targets.to(device)
        out = eaat(x)
        y_true.extend(y.cpu().tolist())
        # the predicted class is the index of the largest logit
        y_pred.extend(torch.argmax(out, 1).cpu().tolist())
test_acc = (np.array(y_true) == np.array(y_pred)).mean() * 100
print('test accuracy: {}'.format(test_acc))
test accuracy: 96.33
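The accuracy computation above reduces to an argmax over logits followed by a match count. The stdlib sketch below reproduces that logic on made-up logits and adds a per-class breakdown, which can reveal digits that the partially labeled model confuses more often than the overall figure suggests. All helper names and values are illustrative.

```python
def argmax(row):
    """Index of the largest value (first one on ties)."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(y_true, y_pred):
    """Percentage of matching labels, as computed in the evaluation cell."""
    return 100.0 * sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def per_class_accuracy(y_true, y_pred, n_classes=10):
    """Accuracy restricted to the samples of each true class."""
    correct = [0] * n_classes
    total = [0] * n_classes
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return [100.0 * c / n if n else float('nan') for c, n in zip(correct, total)]

logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5], [0.0, 0.2, 0.1], [1.0, 0.0, 4.0]]
y_true = [1, 0, 2, 2]
y_pred = [argmax(r) for r in logits]                     # [1, 0, 1, 2]
print(accuracy(y_true, y_pred))                          # 75.0
print(per_class_accuracy(y_true, y_pred, n_classes=3))   # [100.0, 100.0, 50.0]
```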