afterglow.trackers

Submodules

afterglow.trackers.batchnorm module

Adapted from torch.optim.swa_utils

afterglow.trackers.batchnorm.update_bn(loader: torch.utils.data.dataloader.DataLoader, model: torch.nn.modules.module.Module, device: Optional[Union[str, torch.device]] = None, num_datapoints: Optional[int] = None)[source]

Updates BatchNorm running_mean, running_var buffers in the model.

It performs one pass over the data in loader to estimate the activation statistics for BatchNorm layers in the model.

Parameters
  • loader – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.

  • model – model for which we seek to update BatchNorm statistics.

  • device – If set, data will be transferred to device before being passed into model.

  • num_datapoints – number of examples to use to perform the update.

Example

>>> loader, model = ...
>>> afterglow.trackers.batchnorm.update_bn(loader, model)

Note

The update_bn utility assumes that each data batch in loader is either a tensor or a list or tuple of tensors; in the latter case it is assumed that model.forward() should be called on the first element of the list or tuple corresponding to the data batch.
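Since afterglow's update_bn is adapted from torch.optim.swa_utils, a runnable sketch using the upstream torch utility illustrates the same contract (the model and dataset here are synthetic stand-ins):

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import update_bn
from torch.utils.data import DataLoader, TensorDataset

# A tiny model with a BatchNorm layer whose running stats we want to refresh.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU())

# Synthetic (data, label) dataset; update_bn only uses the first element
# of each batch, matching the assumption described above.
data = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
loader = DataLoader(data, batch_size=16)

update_bn(loader, model)  # one pass over loader, refreshing running_mean/var
```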

afterglow.trackers.trackers module

class afterglow.trackers.trackers.CheckpointTracker(module: torch.nn.modules.module.Module, start_iteration: int, update_period_in_iters: int, checkpoint_dir: Union[pathlib.Path, str])[source]

Bases: object

Records model state throughout training.

This class is intended for convenient offline swag enabling, see afterglow.enable.offline.enable_swag_from_checkpoints.

Parameters
  • module – module to enable tracking for.

  • start_iteration – iteration from which to begin recording snapshots.

  • update_period_in_iters – how often to observe the parameters, in iterations.

  • checkpoint_dir – directory in which to store the snapshots.
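The schedule implied by these parameters can be sketched as a plain training loop that saves state_dict snapshots; this is an illustrative stand-in (snapshot_periodically and the file naming are hypothetical, not afterglow's actual implementation):

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

def snapshot_periodically(module, start_iteration, update_period_in_iters,
                          checkpoint_dir, num_iterations):
    # Mimics CheckpointTracker's schedule: once training reaches
    # start_iteration, record the module's state every
    # update_period_in_iters iterations.
    checkpoint_dir = Path(checkpoint_dir)
    for it in range(num_iterations):
        # ... one optimisation step would happen here ...
        if it >= start_iteration and (it - start_iteration) % update_period_in_iters == 0:
            torch.save(module.state_dict(), checkpoint_dir / f"iter_{it}.pt")

module = nn.Linear(3, 1)
with tempfile.TemporaryDirectory() as d:
    snapshot_periodically(module, start_iteration=4, update_period_in_iters=2,
                          checkpoint_dir=d, num_iterations=10)
    saved = sorted(Path(d).glob("*.pt"))  # snapshots at iterations 4, 6, 8
```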

class afterglow.trackers.trackers.SWAGTracker(module: torch.nn.modules.module.Module, start_iteration: pydantic.types.StrictInt, max_cols: afterglow.trackers.trackers.ConstrainedIntValue, update_period_in_iters: pydantic.types.StrictInt, dataloader_for_batchnorm: Optional[torch.utils.data.dataloader.DataLoader] = None, num_datapoints_for_bn_update: Optional[pydantic.types.StrictInt] = None)[source]

Bases: object

Models the parameter distribution over the training trajectory as a multivariate gaussian in a low-rank space. See SWAG paper: https://arxiv.org/abs/1902.02476.

Parameters
  • module – module to enable tracking for.

  • start_iteration – iteration from which to begin fitting the approx posterior.

  • max_cols – the rank to which the posterior covariance matrix is reduced, i.e. the number of deviation columns kept. Must be greater than 1.

  • update_period_in_iters – how often to observe the parameters, in iterations.

  • dataloader_for_batchnorm – if this is provided, we update the model’s batchnorm running means and variances every time we sample a new set of parameters, using the data in the dataloader. This is slow but can improve performance significantly. See the SWAG paper and torch.optim.swa_utils.update_bn. Note that the assumptions made about what iterating over the dataloader returns are the same as those in torch.optim.swa_utils.update_bn: it’s assumed that iterating produces a sequence of (input_batch, label_batch) tuples.

  • num_datapoints_for_bn_update – number of training examples to use to perform the batchnorm update. If None, we use the whole dataset, as in the original SWAG paper. When predicting in online mode (one example at a time) rather than in batch mode, it’s better to set this value to 1 and increase the number of SWAG samples drawn. If this is not None, dataloader_for_batchnorm must be initialised with shuffle=True.
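The low-rank Gaussian that SWAG fits can be made concrete with the paper's sampling rule: a draw is θ = θ̄ + Σ_diag^{1/2} z₁ / √2 + D̂ z₂ / √(2(K−1)), where θ̄ is the running parameter mean, Σ_diag the diagonal variance estimate, and D̂ the matrix of the last K (= max_cols) parameter deviations. A minimal sketch of that rule (illustrative, not afterglow's internals):

```python
import torch

def swag_sample(theta_mean, sigma_diag, dev_matrix):
    # theta_mean: (d,) running average of the flattened parameters (SWA mean)
    # sigma_diag: (d,) diagonal std estimate
    # dev_matrix: (d, K) last K deviations theta_i - theta_mean ("max_cols")
    d, K = dev_matrix.shape
    z1 = torch.randn(d)
    z2 = torch.randn(K)
    # Diagonal half plus low-rank half, as in the SWAG sampling equation.
    return (theta_mean
            + sigma_diag * z1 / 2 ** 0.5
            + dev_matrix @ z2 / (2 * (K - 1)) ** 0.5)

theta = swag_sample(torch.zeros(10), torch.ones(10), torch.randn(10, 5))
```

This is why max_cols must exceed 1: the low-rank term divides by √(2(K−1)).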

predict_uncertainty(*args, num_samples: pydantic.types.StrictInt = 1, dropout: bool = False, prediction_key: Optional[str] = None, device: str = 'cpu', **kwargs) Tuple[torch.Tensor, torch.Tensor][source]

Predict mean and standard deviation of predictive distribution when model inputs are args, kwargs.

Parameters
  • *args – positional arguments to pass to the model

  • **kwargs – keyword arguments to pass to the model

  • num_samples – number of samples to draw

  • dropout – whether to use mc-dropout estimation together with SWAG when sampling

  • prediction_key – if the model returns a dict, this specifies which key contains its predictions; uncertainty will be computed using the contents of this key. Must be provided if the model returns a dict, otherwise ignored.

  • device – device on which the model lives. Only needed for the batchnorm update step, ignored if this doesn’t happen. “cpu” by default.

Returns

mean and standard deviation of the predictive distribution at the inputs args, kwargs
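Conceptually, this reduces to drawing num_samples parameter settings, running the model once per draw, and reducing the stacked outputs. A sketch of that loop, with a stand-in sampler in place of SWAG's sample_state (mc_mean_std and perturb are hypothetical names):

```python
import torch
import torch.nn as nn

def mc_mean_std(model, x, sample_weights, num_samples):
    # sample_weights(model) stands in for drawing a fresh SWAG parameter
    # sample and loading it into the model.
    outputs = []
    for _ in range(num_samples):
        sample_weights(model)
        with torch.no_grad():
            outputs.append(model(x))
    stacked = torch.stack(outputs)  # (num_samples, batch, ...)
    return stacked.mean(dim=0), stacked.std(dim=0)

model = nn.Linear(4, 2)
# Crude stand-in sampler: jitter the weights in place on every draw.
perturb = lambda m: [p.data.add_(0.01 * torch.randn_like(p)) for p in m.parameters()]
mean, std = mc_mean_std(model, torch.randn(8, 4), perturb, num_samples=5)
```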

predict_uncertainty_on_dataloader(dataloader, num_samples: pydantic.types.StrictInt = 1, dropout: bool = False, prediction_key: Optional[str] = None, max_unreduced_minibatches: Optional[int] = None, device: str = 'cpu')[source]

Predict the mean and standard deviation of the model’s output distribution on the examples contained in dataloader. We assume that dataloader returns (input_batch, label_batch) tuples.

Parameters
  • dataloader – contains the examples to sample from the output distribution on

  • num_samples – number of samples to draw

  • dropout – whether to use mc-dropout estimation together with SWAG when sampling

  • prediction_key – if the model returns a dict, this specifies which key contains its predictions; uncertainty will be computed using the contents of this key. Must be provided if the model returns a dict, otherwise ignored.

  • max_unreduced_minibatches – the maximum number of minibatches whose samples to keep in memory before reducing to compute mean and standard deviation. Runs faster for larger values but takes more memory. If not provided, we accumulate samples for the whole dataloader before reducing.

  • device – device on which the model lives. Only needed for the batchnorm update step, ignored if this doesn’t happen. “cpu” by default.

Returns

mean and standard deviation of the predictive distribution for the examples contained in dataloader
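The memory/speed trade-off behind max_unreduced_minibatches can be sketched as chunked accumulation: hold raw samples for at most that many minibatches, reduce the chunk to mean/std, then move on. This is an illustrative reading of the parameter, not afterglow's actual implementation (mean_std_over_batches and draw_and_eval are hypothetical names):

```python
import torch

def mean_std_over_batches(draw_and_eval, batches, num_samples, max_unreduced=None):
    # draw_and_eval(chunk) -> tensor of shape (num_examples_in_chunk, ...),
    # standing in for "sample new parameters, then run the model on the chunk".
    # Raw samples for at most max_unreduced minibatches are held in memory
    # at once; smaller values lower peak memory.
    if max_unreduced is None:
        max_unreduced = len(batches)
    means, stds = [], []
    for start in range(0, len(batches), max_unreduced):
        chunk = batches[start:start + max_unreduced]
        samples = torch.stack([draw_and_eval(chunk) for _ in range(num_samples)])
        means.append(samples.mean(dim=0))
        stds.append(samples.std(dim=0))
    return torch.cat(means, dim=0), torch.cat(stds, dim=0)

batches = [torch.randn(4, 3) for _ in range(5)]
# Stand-in evaluator: concatenate the chunk and add per-draw noise.
noisy_eval = lambda chunk: torch.cat(chunk, dim=0) + 0.1 * torch.randn(1)
mean, std = mean_std_over_batches(noisy_eval, batches, num_samples=8, max_unreduced=2)
```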

predictive_samples(*args, num_samples: pydantic.types.StrictInt = 1, dropout: bool = False, device: str = 'cpu', **kwargs) torch.Tensor[source]

Produce samples from the model’s output distribution given inputs args and kwargs.

Parameters
  • *args – positional arguments to pass to the model

  • **kwargs – keyword arguments to pass to the model

  • num_samples – number of samples to draw

  • dropout – whether to use mc-dropout estimation together with SWAG when sampling

  • device – device on which the model lives. Only needed for the batchnorm update step, ignored if this doesn’t happen. “cpu” by default.

Returns

list of samples from the predictive distribution
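The dropout flag combines SWAG with MC-dropout, which amounts to keeping dropout layers stochastic at inference time. A common way to sketch this in plain PyTorch is to switch only the Dropout modules back into training mode (enable_mc_dropout is a hypothetical helper):

```python
import torch
import torch.nn as nn

def enable_mc_dropout(model):
    # Put only Dropout layers in train mode so stochastic masks are applied
    # at inference, while BatchNorm and friends stay in eval mode.
    model.eval()
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1))
enable_mc_dropout(model)
x = torch.randn(2, 4)
with torch.no_grad():
    samples = torch.stack([model(x) for _ in range(10)])  # stochastic outputs
```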

predictive_samples_on_dataloader(dataloader: torch.utils.data.dataloader.DataLoader, num_samples: pydantic.types.StrictInt = 1, dropout: bool = False, prediction_key: Optional[str] = None, device: str = 'cpu') torch.Tensor[source]

Produce samples from the model’s output distribution on the examples contained in dataloader. We assume that dataloader returns (input_batch, label_batch) tuples.

Parameters
  • dataloader – contains the examples to sample from the output distribution on

  • num_samples – number of samples to draw

  • dropout – whether to use mc-dropout estimation together with SWAG when sampling

  • prediction_key – if the model returns a dict, this specifies which key contains its predictions; uncertainty will be computed using the contents of this key. Must be provided if the model returns a dict, otherwise ignored.

  • device – device on which the model lives.

Returns

list of samples from the predictive distribution

sample_state(device: str = 'cpu')[source]

Update the state of the tracker’s module with a sample from the estimated distribution over parameters.

Parameters

device – where to send the data during the batchnorm update. Ignored if we don’t do a batchnorm update.

save(path: Union[str, pathlib.Path])[source]

Save the uncertainty-enabled model so that it can be loaded using afterglow.load_swag_checkpoint.

Parameters

path – where to save the checkpoint

Module contents