shadow.module_wrapper module

class shadow.module_wrapper.ModuleWrapper(model)

Bases: torch.nn.Module

Base module wrapper for semi-supervised machine learning (SSML) technique implementations.

Parameters

model (torch.nn.Module) – The model to train with semi-supervised learning.
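The sketch below makes the wrapper contract concrete. It assumes only the behavior documented on this page: the constructor stores the model, forward delegates to it, and get_technique_cost raises NotImplementedError until a technique overrides it. The class name SketchModuleWrapper is hypothetical, and this is not the library's source.

```python
import torch
import torch.nn as nn


class SketchModuleWrapper(nn.Module):
    """Hypothetical stand-in mirroring the documented ModuleWrapper interface."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Registering the model as a submodule exposes its parameters
        # through self.parameters(), so an optimizer can train it.
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate directly to the wrapped model.
        return self.model(x)

    def get_technique_cost(self, x: torch.Tensor) -> torch.Tensor:
        # Concrete techniques supply their own SSML cost.
        raise NotImplementedError


wrapper = SketchModuleWrapper(nn.Linear(4, 2))
print(sum(p.numel() for p in wrapper.parameters()))  # 4*2 weights + 2 biases = 10
```

Registering the model as a submodule, rather than holding it in a plain attribute, is what lets a single optimizer built from the wrapper's parameters update the wrapped model.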

forward(x)

Passes data to the wrapped model.

Parameters

x (torch.Tensor) – Input data.

Returns

Model output.

Return type

torch.Tensor
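As a rough illustration of the documented forward behavior, the following sketch checks that calling the wrapper produces the same tensor as calling the bare model. SketchModuleWrapper is a hypothetical stand-in defined here for self-containment; it assumes forward does nothing beyond delegation.

```python
import torch
import torch.nn as nn


class SketchModuleWrapper(nn.Module):
    """Hypothetical stand-in for ModuleWrapper; forward only delegates."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


torch.manual_seed(0)
model = nn.Linear(3, 2)
wrapper = SketchModuleWrapper(model)
x = torch.randn(5, 3)

out = wrapper(x)  # calling the module invokes forward()
print(out.shape)  # torch.Size([5, 2])
print(torch.equal(out, model(x)))  # True: same output as the bare model
```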

get_technique_cost(x)

Compute the SSML-related cost for the implemented technique.

Parameters

x (torch.Tensor) – Input data.

Returns

Technique-specific cost.

Return type

torch.Tensor

Raises

NotImplementedError – If the inheriting technique does not implement this method.

Note

This method must be implemented by each technique that inherits from this base class.
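The sketch below shows one way a subclass might override get_technique_cost. Entropy minimization is swapped in purely as an illustrative SSML technique; shadow's own techniques may compute their costs differently. SketchModuleWrapper and EntropyMinimization are hypothetical names, and the loss weight of 1.0 is an arbitrary assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchModuleWrapper(nn.Module):
    """Hypothetical stand-in for the documented ModuleWrapper base."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def get_technique_cost(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class EntropyMinimization(SketchModuleWrapper):
    """Hypothetical technique: penalize uncertain predictions on unlabeled input."""

    def get_technique_cost(self, x: torch.Tensor) -> torch.Tensor:
        # Log class probabilities from the wrapped model's logits.
        log_p = F.log_softmax(self(x), dim=-1)
        # Batch mean of the per-sample Shannon entropy -sum(p * log p).
        return -(log_p.exp() * log_p).sum(dim=-1).mean()


torch.manual_seed(0)
technique = EntropyMinimization(nn.Linear(4, 3))
x_labeled, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
x_unlabeled = torch.randn(16, 4)

# Supervised loss on labeled data plus the technique cost on unlabeled data;
# the weight of 1.0 is an illustrative assumption.
loss = F.cross_entropy(technique(x_labeled), y) + 1.0 * technique.get_technique_cost(x_unlabeled)
loss.backward()
```

This page documents only that get_technique_cost returns the technique-specific cost; how that cost is weighted against a supervised loss during training is an assumption of this sketch.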