shadow.mt module
-
class shadow.mt.MT(model, alpha=0.999, noise=0.1, consistency_type='mse')
Bases: shadow.module_wrapper.ModuleWrapper
Mean Teacher [Tarvainen17] model wrapper for consistency regularization.
Mean Teacher model wrapper that provides both the student and the teacher model implementations. The teacher model is a running average of the student weights and is updated during training. When switched to eval mode, the teacher model is used for predictions instead of the student. Because the wrapper handles the handoff between the student and teacher models, use the wrapper rather than the student model directly.
- Parameters
model (torch.nn.Module) – The student model.
alpha (float, optional) – The teacher exponential moving average smoothing coefficient. Defaults to 0.999.
noise (float, optional) – If > 0.0, the standard deviation of the Gaussian noise applied to the input. Specifically, random numbers are drawn from a normal distribution with mean 0 and variance 1, scaled by this factor, and added to the input data. Defaults to 0.1.
consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to ‘mse’ (mean squared error).
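The noise parameter described above can be illustrated with a minimal sketch. It uses plain Python lists rather than torch.Tensor, and the helper name add_gaussian_noise is hypothetical (it is not part of shadow); it only shows the "draw from N(0, 1), scale by noise, add to the input" rule.

```python
import random


def add_gaussian_noise(x, noise=0.1, rng=random):
    """Return a noisy copy of x (hypothetical helper, not shadow's API).

    Each element becomes x_i + noise * z_i with z_i drawn from N(0, 1),
    so `noise` acts as the standard deviation of the perturbation.
    """
    if noise <= 0.0:
        # No perturbation requested: return an unchanged copy.
        return list(x)
    return [xi + noise * rng.gauss(0.0, 1.0) for xi in x]


clean = [1.0, 2.0, 3.0]
noisy = add_gaussian_noise(clean, noise=0.1, rng=random.Random(0))
```

Passing a seeded random.Random instance keeps the sketch reproducible; with noise=0.0 the input comes back unchanged.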
-
calc_student_logits(x)
Student model logits, with noise added to the input data.
- Parameters
x (torch.Tensor) – Input data.
- Returns
The student logits.
- Return type
torch.Tensor
-
calc_teacher_logits(x)
Teacher model logits.
Computes the teacher logits with noise added to the input data. Gradients are not propagated through the teacher forward pass.
- Parameters
x (torch.Tensor) – Input data.
- Returns
The teacher logits.
- Return type
torch.Tensor
-
forward(x)
Model forward pass.
During training, adds noise to the input data and passes it through the student model. During evaluation, adds no noise and passes the input through the teacher model.
- Parameters
x (torch.Tensor) – Input data.
- Returns
Model output.
- Return type
torch.Tensor
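The train/eval dispatch described for forward can be sketched as follows. This is an illustration under stated assumptions, not shadow's implementation: TinyMeanTeacher, its perturb argument, and the plain callables standing in for torch.nn.Module objects are all hypothetical.

```python
class TinyMeanTeacher:
    """Toy stand-in for the wrapper's forward logic (hypothetical class)."""

    def __init__(self, student, teacher, perturb):
        self.student = student    # callable used in training mode
        self.teacher = teacher    # callable used in evaluation mode
        self.perturb = perturb    # callable that adds noise to the input
        self.training = True      # mirrors the torch.nn.Module flag

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def forward(self, x):
        if self.training:
            # Training: noisy input through the student.
            return self.student(self.perturb(x))
        # Evaluation: clean input through the teacher.
        return self.teacher(x)


wrapper = TinyMeanTeacher(
    student=lambda x: ("student", x),
    teacher=lambda x: ("teacher", x),
    perturb=lambda x: x + 1,  # deterministic "noise" for illustration
)
train_out = wrapper.forward(1)
eval_out = wrapper.eval().forward(1)
```

The point of the sketch is that callers never pick a model themselves; switching the mode on the wrapper is what selects the student or the teacher.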
-
get_evaluation_model()
The teacher model, which should be used for prediction during evaluation.
- Returns
The teacher model.
- Return type
torch.nn.Module
-
get_technique_cost(x)
Consistency cost between student and teacher models.
Computes the consistency cost between the student and teacher outputs and updates the teacher weights via an exponential moving average of the student weights. Noise is sampled and applied to the student and teacher inputs separately.
- Parameters
x (torch.Tensor) – Input data.
- Returns
Consistency cost between the student and teacher model outputs.
- Return type
torch.Tensor
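As a rough illustration of the consistency cost with consistency_type='mse', the sketch below perturbs the input twice (once per model) and takes the mean squared error between the two outputs. It is a simplification under assumptions: models are plain callables on lists, consistency_cost is a hypothetical name, and any output transformation shadow may apply before the comparison (for example a softmax) is not shown because the text above does not specify it.

```python
import random


def consistency_cost(student, teacher, x, noise=0.1, rng=random):
    """Mean squared error between student and teacher outputs (sketch)."""

    def perturb(values):
        # Fresh Gaussian noise on every call, so the student and the
        # teacher each see a differently perturbed copy of the input.
        return [v + noise * rng.gauss(0.0, 1.0) for v in values]

    student_out = student(perturb(x))
    teacher_out = teacher(perturb(x))  # treated as a fixed target
    squared = [(s - t) ** 2 for s, t in zip(student_out, teacher_out)]
    return sum(squared) / len(squared)


def double(values):
    return [2.0 * v for v in values]


def identity(values):
    return list(values)


# With noise disabled the cost reduces to a plain MSE between outputs.
cost = consistency_cost(double, identity, [1.0, 2.0], noise=0.0)
```

In the documented wrapper this call also updates the teacher weights; that step is left out here to keep the cost computation isolated.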
-
shadow.mt.ema_update_model(student_model, ema_model, alpha, global_step)
Exponential moving average update of a model.
Update ema_model to be the moving average of consecutive student_model updates via an exponential weighting (as defined in [Tarvainen17]). Update is performed in-place.
- Parameters
student_model (torch.nn.Module) – The student model.
ema_model (torch.nn.Module) – The model to update (teacher model). Update is performed in-place.
alpha (float) – Exponential moving average smoothing coefficient, in the range [0, 1].
global_step (int) – A running count of exponential update steps (typically mini-batch updates).
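The update rule can be sketched on plain dictionaries of floats instead of torch.nn.Module parameters. The function name ema_update is hypothetical. The warm-up line, which caps alpha at 1 - 1 / (global_step + 1), is an assumption based on the reference Mean Teacher code; it would explain why global_step is a parameter, but whether shadow applies exactly this ramp is not confirmed by the text above.

```python
def ema_update(student, ema, alpha, global_step):
    """In-place EMA update of `ema` toward `student` (sketch).

    Both arguments map parameter names to floats.
    """
    # Assumed warm-up: early steps use a smaller effective alpha so the
    # average follows the student closely before settling at `alpha`.
    alpha = min(1.0 - 1.0 / (global_step + 1), alpha)
    for name, value in student.items():
        ema[name] = alpha * ema[name] + (1.0 - alpha) * value


teacher = {"w": 0.0}
ema_update({"w": 1.0}, teacher, alpha=0.999, global_step=0)
# At step 0 the effective alpha is 0, so the teacher copies the student.
first = teacher["w"]
ema_update({"w": 3.0}, teacher, alpha=0.5, global_step=999)
# Effective alpha is 0.5 here: 0.5 * 1.0 + 0.5 * 3.0
second = teacher["w"]
```

A larger alpha makes the teacher change more slowly, which is why the documented default of 0.999 yields a heavily smoothed copy of the student.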