afterglow.enable

Submodules

afterglow.enable.base module

afterglow.enable.base.load_swag_checkpoint(base_module: torch.nn.modules.module.Module, path: Union[str, pathlib.Path], dataloader_for_batchnorm: Optional[torch.utils.data.dataloader.DataLoader] = None, num_datapoints_for_bn_update: Optional[pydantic.types.StrictInt] = None)[source]

Loads the state dict of a SWAG-enabled model that was saved via model.trajectory_tracker.save into base_module, after enabling SWAG on base_module.

Parameters
  • base_module – An instance of the module to load the SWAG checkpoint into.

  • path – Path to the checkpoint.

  • dataloader_for_batchnorm – see enable_swag_from_checkpoints.

  • num_datapoints_for_bn_update – see enable_swag_from_checkpoints.

afterglow.enable.offline module

afterglow.enable.offline.enable_checkpointing(module: torch.nn.modules.module.Module, start_iteration: pydantic.types.StrictInt, checkpoint_dir: Union[str, pathlib.Path], update_period_in_iters: Optional[pydantic.types.StrictInt] = None)[source]

Convenience function to save checkpoints during a run in a format that will work easily with enable_swag_from_checkpoints. If you use this function for checkpointing, you can call enable_swag_from_checkpoints with checkpoint_pattern and checkpoint_sort_key left as the defaults.

Parameters
  • module – the module to enable checkpointing for

  • start_iteration – iteration at which to start saving checkpoints

  • update_period_in_iters – how often to save the parameters, in iterations.

  • checkpoint_dir – directory to save the checkpoints in. Need not exist.
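The two timing parameters define a simple modular schedule. A stdlib sketch of which training iterations would produce a checkpoint, under the hypothetical assumption that a save happens at start_iteration and every update_period_in_iters thereafter (the exact save timing is an afterglow internal):

```python
def checkpoint_iterations(start_iteration, update_period_in_iters, total_iters):
    """Iterations at which a checkpoint would be written, assuming
    (hypothetically) that saving starts at start_iteration and repeats
    every update_period_in_iters."""
    return [i for i in range(total_iters)
            if i >= start_iteration
            and (i - start_iteration) % update_period_in_iters == 0]

# e.g. start at iteration 4, save every 3 iterations, over a 12-iteration run
print(checkpoint_iterations(4, 3, 12))  # [4, 7, 10]
```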

afterglow.enable.offline.enable_swag_from_checkpoints(module: torch.nn.modules.module.Module, max_cols: afterglow.enable.offline.ConstrainedIntValue, checkpoint_dir: pathlib.Path, start_iteration: pydantic.types.StrictInt = 0, checkpoint_pattern: str = '*.ckpt', checkpoint_sort_key: Callable[[str], float] = <function _iter_from_filepath>, dataloader_for_batchnorm: Optional[torch.utils.data.dataloader.DataLoader] = None, num_datapoints_for_bn_update: Optional[pydantic.types.StrictInt] = None) afterglow._types.SwagEnabledModule[source]

Equips a model with SWAG-based uncertainty estimation by reconstructing the training trajectory from a series of saved checkpoints. Useful if you have non-SWAG-enabled checkpoints saved for an expensive-to-train model that you want to try SWAG on.

Calling this on a model equips it with a trajectory_tracker object which provides SWAG-sampling methods. Example usage:

my_model = MyModel()
# assumes your checkpoints are of the form "./checkpoints/<epoch-num>.pt"
enable_swag_from_checkpoints(
    my_model,
    max_cols=10,
    checkpoint_dir="./checkpoints",
    checkpoint_pattern="*.pt",
    checkpoint_sort_key=lambda p: int(Path(p).stem),
)
my_model.trajectory_tracker.predict_uncertainty(data)

Parameters
  • module – The module to enable SWAG for.

  • max_cols – Number of checkpoints to use in calculating the SWAG covariance matrix. Values between 10 and 20 are usually reasonable. See SWAG paper for details.

  • checkpoint_dir – Directory where the checkpoints from the training run you want to apply SWAG to are found.

  • start_iteration – iteration from which to begin including checkpoints; checkpoints saved before this iteration are ignored.

  • checkpoint_pattern – A glob pattern that, when applied to checkpoint_dir, will select the checkpoints you want to include.

  • checkpoint_sort_key – Function mapping a checkpoint filename to a number used to order the checkpoints chronologically.

  • dataloader_for_batchnorm – If this is provided, the model’s batchnorm running means and variances are updated, using the data in the dataloader, every time a new set of parameters is sampled. This is slow but can improve performance significantly; see the SWAG paper and torch.optim.swa_utils.update_bn. As in torch.optim.swa_utils.update_bn, iterating over the dataloader is assumed to produce a sequence of (input_batch, label_batch) tuples.

  • num_datapoints_for_bn_update – Number of training examples to use to perform the batchnorm update. If None, the whole dataset is used, as in the original SWAG paper. When predicting in online mode (one example at a time) rather than in batch mode, it is better to set this value to 1 and increase the number of SWAG samples drawn. If this is not None, dataloader_for_batchnorm must be initialised with shuffle=True.
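checkpoint_pattern and checkpoint_sort_key together determine which files are used and in what order. A minimal stdlib sketch of that selection logic (the filenames here are hypothetical, and the real implementation globs checkpoint_dir rather than filtering a list):

```python
from fnmatch import fnmatch
from pathlib import Path

def sort_key(filepath):
    """Parse the epoch number out of hypothetical names like 'model-epoch-012.pt'."""
    return float(Path(filepath).stem.rsplit("-", 1)[-1])

# Hypothetical directory contents: only files matching the pattern are kept,
# then ordered chronologically by the sort key.
files = ["model-epoch-010.pt", "model-epoch-002.pt", "notes.txt", "model-epoch-001.pt"]
selected = sorted((f for f in files if fnmatch(f, "model-epoch-*.pt")), key=sort_key)
print(selected)  # oldest checkpoint first
```

For this naming scheme you would pass checkpoint_pattern="model-epoch-*.pt" and checkpoint_sort_key=sort_key to enable_swag_from_checkpoints.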

afterglow.enable.online module

afterglow.enable.online.enable_swag(module: torch.nn.modules.module.Module, start_iteration: pydantic.types.StrictInt, max_cols: afterglow._types.ConstrainedIntValue, update_period_in_iters: pydantic.types.StrictInt, dataloader_for_batchnorm: Optional[torch.utils.data.dataloader.DataLoader] = None, num_datapoints_for_bn_update: Optional[pydantic.types.StrictInt] = None) afterglow._types.SwagEnabledModule[source]

Enables online trajectory tracking. Models passed to this function and subsequently trained will update SWAG buffers during training, and will be equipped with the ability to sample from the SWAG posterior via a trajectory_tracker object once training is done.

Calling this on a model equips it with a trajectory_tracker object which provides SWAG-sampling methods. Example usage:

my_model = MyModel()
enable_swag(
    my_model,
    start_iteration=10 * len(train_dataloader),  # e.g. begin tracking after 10 epochs
    max_cols=10,
    update_period_in_iters=len(train_dataloader),  # update once per epoch
)
trainer.fit(my_model, train_dataloader)
my_model.trajectory_tracker.predict_uncertainty(data)

Parameters
  • module – The module to enable SWAG for.

  • start_iteration – iteration at which to begin updating the SWAG buffers.
  • max_cols – Number of checkpoints to use in calculating the SWAG covariance matrix. Values between 10 and 20 are usually reasonable. See SWAG paper for details.

  • update_period_in_iters – The interval between SWAG buffer updates, in iterations. This is usually set to the number of iterations per epoch, which you can get with len(train_dataloader).

  • dataloader_for_batchnorm – If this is provided, the model’s batchnorm running means and variances are updated, using the data in the dataloader, every time a new set of parameters is sampled. This is slow but can improve performance significantly; see the SWAG paper and torch.optim.swa_utils.update_bn. As in torch.optim.swa_utils.update_bn, iterating over the dataloader is assumed to produce a sequence of (input_batch, label_batch) tuples.

  • num_datapoints_for_bn_update – Number of training examples to use to perform the batchnorm update. If None, the whole dataset is used, as in the original SWAG paper. When predicting in online mode (one example at a time) rather than in batch mode, it is better to set this value to 1 and increase the number of SWAG samples drawn. If this is not None, dataloader_for_batchnorm must be initialised with shuffle=True.
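For the common one-update-per-epoch choice, update_period_in_iters is the number of batches per epoch, which is what len(train_dataloader) returns. A quick sanity check of that count (the dataset and batch sizes here are made up):

```python
import math

# Hypothetical sizes. With drop_last=False (the DataLoader default),
# len(train_dataloader) == ceil(dataset_size / batch_size); with
# drop_last=True it would be floor(dataset_size / batch_size) instead.
dataset_size, batch_size = 50_000, 128
iters_per_epoch = math.ceil(dataset_size / batch_size)
print(iters_per_epoch)  # 391
```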

Module contents