Source code for shadow.module_wrapper

import torch


class ModuleWrapper(torch.nn.Module):
    r"""Common base class for the SSML technique implementations.

    A technique wraps an existing model and supplies its own
    semi-supervised cost by overriding :meth:`get_technique_cost`.

    Args:
        model (torch.nn.Module): Model that will be trained with
            semi-supervised learning. Stored on the wrapper as ``model``.
    """

    def __init__(self, model):
        super().__init__()
        # When ``model`` is a ``torch.nn.Module``, assigning it as an
        # attribute registers it as a child module, so its parameters are
        # reachable through this wrapper's ``parameters()``.
        self.model = model

    def forward(self, x):
        r"""Delegate the forward pass to the wrapped model.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Output of the wrapped model for ``x``.
        """
        wrapped = self.model
        return wrapped(x)

    def get_technique_cost(self, x):
        r"""Return the technique-specific SSML cost for ``x``.

        This base implementation is only a placeholder. Every concrete
        technique that derives from this class has to provide its own.

        Args:
            x (torch.Tensor): Input data.

        Returns:
            torch.Tensor: Cost term contributed by the technique.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        message = "Must override get_technique_cost function"
        raise NotImplementedError(message)