Module laplace.utils.swag
Functions
def fit_diagonal_swag_var(model: nn.Module, train_loader: DataLoader, criterion: nn.CrossEntropyLoss | nn.MSELoss, n_snapshots_total: int = 40, snapshot_freq: int = 1, lr: float = 0.01, momentum: float = 0.9, weight_decay: float = 0.0003, min_var: float = 1e-30) -> torch.Tensor
Fit diagonal SWAG [1], which estimates the marginal variances of model parameters from the first and second moments of SGD iterates collected with a large learning rate.
Implementation partly adapted from:
- https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
- https://github.com/wjmaddox/swa_gaussian/blob/master/experiments/train/run_swag.py
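The moment-based estimate described above can be illustrated with a small, library-free sketch. It assumes the variance of each parameter is taken as the running mean of the squared snapshots minus the squared running mean, clamped from below at min_var. The helper name diagonal_variance and the snapshot values are hypothetical and are not part of laplace.utils.swag.

```python
def diagonal_variance(snapshots, min_var=1e-30):
    """Per-parameter variance from a list of flattened parameter snapshots.

    Hypothetical helper for illustration only.
    """
    dim = len(snapshots[0])
    mean = [0.0] * dim     # running first moment, E[theta]
    sq_mean = [0.0] * dim  # running second moment, E[theta^2]
    for k, theta in enumerate(snapshots, start=1):
        for i, t in enumerate(theta):
            # Incremental updates, so snapshots never need to be stored together.
            mean[i] += (t - mean[i]) / k
            sq_mean[i] += (t * t - sq_mean[i]) / k
    # Var[theta] = E[theta^2] - E[theta]^2, clamped for numerical stability.
    return [max(sq_mean[i] - mean[i] ** 2, min_var) for i in range(dim)]


# Three made-up snapshots of a two-parameter model.
variances = diagonal_variance([[1.0, 0.5], [3.0, 0.5], [2.0, 0.5]])
```

The second parameter never changes across snapshots, so its raw variance is zero and the returned value is the min_var floor.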
References
[1] Maddox, W., Garipov, T., Izmailov, P., Vetrov, D., Wilson, A. G. A Simple Baseline for Bayesian Uncertainty in Deep Learning. NeurIPS 2019.
Parameters
model:torch.nn.Module
train_loader:torch.utils.data.DataLoader- training data loader to use for snapshot collection
criterion:torch.nn.CrossEntropyLoss or torch.nn.MSELoss- loss function to use for snapshot collection
n_snapshots_total:int- total number of model snapshots to collect
snapshot_freq:int- snapshot collection frequency (in epochs)
lr:float- SGD learning rate for collecting snapshots
momentum:float- SGD momentum
weight_decay:float- SGD weight decay
min_var:float- minimum parameter variance to clamp to (for numerical stability)
Returns
param_variances:torch.Tensor- vector of marginal variances for each model parameter
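To show how the parameters above might fit together, here is a minimal PyTorch sketch of a snapshot-collection loop. It is an illustrative re-implementation under stated assumptions, not the library's code: the function name diagonal_swag_sketch, the choice to train a deep copy of the model, the decision to take a snapshot at the end of every snapshot_freq-th epoch, and the toy regression data are all assumptions.

```python
from copy import deepcopy

import torch
from torch import nn
from torch.nn.utils import parameters_to_vector
from torch.utils.data import DataLoader, TensorDataset


def diagonal_swag_sketch(model, train_loader, criterion, n_snapshots_total=40,
                         snapshot_freq=1, lr=0.01, momentum=0.9,
                         weight_decay=3e-4, min_var=1e-30):
    """Hypothetical sketch of diagonal SWAG variance estimation."""
    model = deepcopy(model)  # assumption: leave the caller's weights untouched
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                momentum=momentum, weight_decay=weight_decay)
    theta = parameters_to_vector(model.parameters()).detach()
    mean = torch.zeros_like(theta)     # running E[theta]
    sq_mean = torch.zeros_like(theta)  # running E[theta^2]
    n_snapshots = 0
    for epoch in range(snapshot_freq * n_snapshots_total):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
        # Assumption: snapshot after every snapshot_freq-th epoch.
        if (epoch + 1) % snapshot_freq == 0:
            theta = parameters_to_vector(model.parameters()).detach()
            n_snapshots += 1
            mean += (theta - mean) / n_snapshots
            sq_mean += (theta ** 2 - sq_mean) / n_snapshots
    # Marginal variance per parameter, clamped at min_var.
    return torch.clamp(sq_mean - mean ** 2, min=min_var)


# Toy regression problem (made-up data) to exercise the sketch.
torch.manual_seed(0)
X, y = torch.randn(64, 3), torch.randn(64, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
net = nn.Linear(3, 1)
param_variances = diagonal_swag_sketch(net, loader, nn.MSELoss(),
                                       n_snapshots_total=5)
```

The returned vector has one entry per scalar parameter, in the order produced by parameters_to_vector; for the nn.Linear(3, 1) model here that is three weights followed by one bias.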