shadow.mt module

class shadow.mt.MT(model, alpha=0.999, noise=0.1, consistency_type='mse')

Bases: shadow.module_wrapper.ModuleWrapper

Mean Teacher [Tarvainen17] model wrapper for consistency regularization.

Mean Teacher model wrapper that provides both the student and teacher model implementations. The teacher model is a running average of the student weights and is updated during training. When switched to eval mode, the wrapper uses the teacher model for predictions instead of the student. Because the wrapper handles the hand-off between the student and teacher models, it should be used in place of the student model directly.
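Following [Tarvainen17], the teacher weights θ' at training step t are an exponential moving average of the student weights θ, where α is the `alpha` smoothing coefficient documented below:

```latex
\theta'_t = \alpha \, \theta'_{t-1} + (1 - \alpha) \, \theta_t
```

With the default α = 0.999, each update moves the teacher weights 0.1% of the way toward the current student weights.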

Parameters
  • model (torch.nn.Module) – The student model.

  • alpha (float, optional) – The teacher exponential moving average smoothing coefficient. Defaults to 0.999.

  • noise (float, optional) – If > 0.0, the standard deviation of Gaussian noise added to the input. Specifically, random numbers are drawn from a normal distribution with mean 0 and variance 1, scaled by this factor, and added to the input data. Defaults to 0.1.

  • consistency_type ({'kl', 'mse', 'mse_regress'}, optional) – Cost function used to measure consistency. Defaults to 'mse' (mean squared error).
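The noise described for the `noise` parameter can be sketched in plain Python as below. This is an illustration only: `add_gaussian_noise` is a hypothetical helper operating on lists of floats, not part of shadow, which works on `torch.Tensor` inputs (a tensor version would typically use `torch.randn_like(x) * noise`).

```python
import random
from typing import List, Optional


def add_gaussian_noise(x: List[float], noise: float = 0.1,
                       rng: Optional[random.Random] = None) -> List[float]:
    """Hypothetical helper: return x + noise * N(0, 1), element-wise."""
    if noise <= 0.0:
        # Matches the documented "If > 0.0" condition: the input is returned unchanged.
        return list(x)
    rng = rng or random.Random()
    # Draw standard-normal samples (mean 0, variance 1), scale them by `noise`,
    # and add them to the input values.
    return [xi + noise * rng.gauss(0.0, 1.0) for xi in x]


noisy = add_gaussian_noise([1.0, 2.0, 3.0], noise=0.1, rng=random.Random(0))
print(noisy)
```

Because the samples are standard normal, `noise` acts directly as the standard deviation of the perturbation.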

calc_student_logits(x)

Student model logits, with noise added to the input data.

Parameters

x (torch.Tensor) – Input data.

Returns

The student logits.

Return type

torch.Tensor

calc_teacher_logits(x)

Teacher model logits.

Computes the teacher model logits, with noise added to the input data. Gradients are not propagated through the teacher forward pass.

Parameters

x (torch.Tensor) – Input data.

Returns

The teacher logits.

Return type

torch.Tensor

forward(x)

Model forward pass.

During training, noise is added to the input data, which is then passed through the student model. During evaluation, no noise is added and the input is passed through the teacher model.
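The train/eval routing described above can be illustrated with a small pure-Python stand-in. `ToyMeanTeacherWrapper` is hypothetical and only mimics the documented behavior; the real `MT` class is a `torch.nn.Module` whose mode is switched with the standard `train()` and `eval()` methods.

```python
import random
from typing import Callable, List

Model = Callable[[List[float]], List[float]]


class ToyMeanTeacherWrapper:
    """Hypothetical stand-in for the documented student/teacher routing."""

    def __init__(self, student: Model, teacher: Model, noise: float = 0.1, seed: int = 0):
        self.student = student
        self.teacher = teacher
        self.noise = noise
        self.training = True  # analogous to torch.nn.Module.training
        self._rng = random.Random(seed)

    def train(self) -> "ToyMeanTeacherWrapper":
        self.training = True
        return self

    def eval(self) -> "ToyMeanTeacherWrapper":
        self.training = False
        return self

    def forward(self, x: List[float]) -> List[float]:
        if self.training:
            # Training: perturb the input with scaled standard-normal noise, then use the student.
            noisy = [xi + self.noise * self._rng.gauss(0.0, 1.0) for xi in x]
            return self.student(noisy)
        # Evaluation: no noise is added and the teacher produces the output.
        return self.teacher(x)


wrapper = ToyMeanTeacherWrapper(student=lambda x: [2 * v for v in x],
                                teacher=lambda x: [3 * v for v in x])
print(wrapper.eval().forward([1.0, 2.0]))  # -> [3.0, 6.0]
```

Routing inside `forward` is what lets callers use the wrapper everywhere instead of choosing between the student and teacher themselves.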

Parameters

x (torch.Tensor) – Input data.

Returns

Model output.

Return type

torch.Tensor

get_evaluation_model()

The teacher model, which should be used for prediction during evaluation.

Returns

The teacher model.

Return type

torch.nn.Module

get_technique_cost(x)

Consistency cost between student and teacher models.

Computes the consistency cost between the student and teacher models and updates the teacher weights via an exponential moving average of the student weights. Noise is sampled and applied to the student and teacher inputs separately.
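A minimal pure-Python sketch of the cost computation, under stated assumptions: following the Mean Teacher paper, it assumes the 'mse' option compares softmax probabilities and the 'kl' option computes KL(teacher || student) on softmax outputs. The class-averaged reduction is also an assumption, and shadow's exact reduction may differ. All names are hypothetical. The teacher EMA update and the 'mse_regress' option are left out, and in PyTorch the teacher pass would run without gradients.

```python
import math
import random
from typing import Callable, List, Optional

Model = Callable[[List[float]], List[float]]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_mse(student_logits: List[float], teacher_logits: List[float]) -> float:
    # Mean squared difference between class probabilities (averaged over classes).
    p_s, p_t = softmax(student_logits), softmax(teacher_logits)
    return sum((s - t) ** 2 for s, t in zip(p_s, p_t)) / len(p_s)


def softmax_kl(student_logits: List[float], teacher_logits: List[float]) -> float:
    # KL(teacher || student): teacher probabilities act as the target distribution.
    p_s, p_t = softmax(student_logits), softmax(teacher_logits)
    return sum(t * (math.log(t) - math.log(s)) for s, t in zip(p_s, p_t) if t > 0)


def toy_technique_cost(student: Model, teacher: Model, x: List[float], noise: float = 0.1,
                       consistency: str = "mse",
                       rng: Optional[random.Random] = None) -> float:
    rng = rng or random.Random()
    # Independent noise draws for the student and teacher inputs.
    x_student = [xi + noise * rng.gauss(0.0, 1.0) for xi in x]
    x_teacher = [xi + noise * rng.gauss(0.0, 1.0) for xi in x]
    s_logits, t_logits = student(x_student), teacher(x_teacher)
    if consistency == "mse":
        return softmax_mse(s_logits, t_logits)
    if consistency == "kl":
        return softmax_kl(s_logits, t_logits)
    raise ValueError(f"unsupported consistency type: {consistency!r}")
```

Sampling the two noise draws independently means the cost is nonzero even when the student and teacher weights are identical, which is what pushes the student toward predictions that are stable under input perturbations.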

Parameters

x (torch.Tensor) – Input data.

Returns

Consistency cost between the student and teacher model outputs.

Return type

torch.Tensor

update_ema_model()

Exponential moving average update of the teacher model.

shadow.mt.ema_update_model(student_model, ema_model, alpha, global_step)

Exponential moving average update of a model.

Updates ema_model to be the moving average of consecutive student_model updates via exponential weighting (as defined in [Tarvainen17]). The update is performed in place.

Parameters
  • student_model (torch.nn.Module) – The student model.

  • ema_model (torch.nn.Module) – The model to update (teacher model). Update is performed in-place.

  • alpha (float) – Exponential moving average smoothing coefficient, in the range [0, 1].

  • global_step (int) – A running count of exponential update steps (typically mini-batch updates).
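The update can be sketched over plain Python parameter dictionaries as below. `ema_update_params` is a hypothetical analogue of `ema_update_model`, which works on `torch.nn.Module` parameters. The warm-up rule that caps alpha at `1 - 1 / (global_step + 1)` is an assumption taken from the Mean Teacher reference implementation and is not confirmed by this documentation. It would explain why `global_step` is a parameter: early steps use the true running average rather than a heavily smoothed one.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def ema_update_params(student: Params, ema: Params, alpha: float, global_step: int) -> None:
    """Hypothetical in-place EMA update of `ema` toward `student`."""
    # Assumed warm-up: until t / (t + 1) exceeds alpha, use the true running average.
    alpha = min(1.0 - 1.0 / (global_step + 1), alpha)
    for name, student_values in student.items():
        ema_values = ema[name]
        for i, s in enumerate(student_values):
            # theta'_t = alpha * theta'_{t-1} + (1 - alpha) * theta_t, updated in place.
            ema_values[i] = alpha * ema_values[i] + (1.0 - alpha) * s


student = {"w": [1.0, 2.0]}
ema = {"w": [0.0, 0.0]}
ema_update_params(student, ema, alpha=0.999, global_step=0)
print(ema["w"])  # -> [1.0, 2.0]
```

At `global_step=0` the capped alpha is 0, so the teacher simply copies the student. Once `global_step` grows large enough, the requested `alpha` takes over.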