Module laplace.lllaplace
Classes
class LLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Base class for all last-layer Laplace approximations in this library. Subclasses specify the structure of the Hessian approximation. See BaseLaplace for the full interface.
A Laplace approximation is represented by a MAP estimate, given by the model parameter, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). Here, only the parameters of the last layer of the neural network are treated probabilistically. The goal of this class is to compute the posterior precision P, which is given by the sum P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example a diagonal one. The prior is assumed to be Gaussian, and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be varied.
Parameters
model : torch.nn.Module or FeatureExtractor
likelihood : Likelihood or {'classification', 'regression'} - determines the log likelihood Hessian approximation
sigma_noise : torch.Tensor or float, default=1 - observation noise for the regression setting; must be 1 for classification
prior_precision : torch.Tensor or float, default=1 - prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case
prior_mean : torch.Tensor or float, default=0 - prior mean of a Gaussian prior, useful for continual learning
temperature : float, default=1 - temperature of the likelihood; lower temperature leads to a more concentrated posterior and vice versa
enable_backprop : bool, default=False - whether to enable backprop to the input x through the Laplace predictive. Useful, e.g., for Bayesian optimization.
feature_reduction : FeatureReduction or str, optional, default=None - when the last-layer features is a tensor of dim >= 3, this tells how to reduce it into a dim-2 tensor. E.g. in LLMs for non-language-modeling problems, the penultimate output is a tensor of shape (batch_size, seq_len, embd_dim), but the last layer maps (batch_size, embd_dim) to (batch_size, n_classes). Note: make sure that this option faithfully reflects the reduction in the model definition. When inputting a string, available options are {'pick_first', 'pick_last', 'average'}.
dict_key_x : str, default='input_ids' - the dictionary key under which the input tensor x is stored. Only has an effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels' - the dictionary key under which the target tensor y is stored. Only has an effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
backend : subclasses of CurvatureInterface - backend for access to curvature/Hessian approximations
last_layer_name : str, default=None - name of the model's last layer; if None, it will be determined automatically
backend_kwargs : dict, default=None - arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations
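A minimal usage sketch (hedged: the toy model, data, and hyperparameter values below are placeholders; the Laplace factory with subset_of_weights='last_layer' dispatches to the subclasses documented on this page, and hessian_structure selects among the full, kron, and diag variants):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace import Laplace

# Toy classifier and data (placeholders for a real model and dataset).
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 3))
X, y = torch.randn(100, 20), torch.randint(0, 3, (100,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=10)

# Last-layer Laplace with a Kronecker-factored Hessian approximation.
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

# Posterior predictive on new inputs.
pred = la(torch.randn(5, 20), pred_type='glm', link_approx='probit')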
Ancestors
Subclasses
Instance variables
var prior_precision_diag : torch.Tensor
Obtain the diagonal prior precision p_0 constructed from either a scalar or diagonal prior precision.
Returns
prior_precision_diag : torch.Tensor
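A sketch of the expansion this property performs, continuing the example above (hedged: assumes a scalar prior precision is broadcast to one entry per last-layer parameter):

# A scalar prior precision is expanded to a diagonal over the
# last-layer parameters.
la.prior_precision = 0.5
p0 = la.prior_precision_diag  # shape (n_last_layer_params,), all entries 0.5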
Methods
def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader : torch.utils.data.DataLoader - each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access N, the size of the data set.
override : bool, default=True - whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False - whether to show a progress bar while fitting
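A sketch of the sequential-posterior pattern that override=False enables (hedged: train_loader_task1 and train_loader_task2 are hypothetical loaders, continuing the setup above):

# Accumulate a posterior over two data streams instead of refitting.
la.fit(train_loader_task1)                   # fresh H, loss, n_data (override=True)
la.fit(train_loader_task2, override=False)   # add curvature from the second stream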
def functional_variance_fast(self, X)
Compute the diagonal of the output covariance, i.e., the marginal predictive variances. Should be overridden if there exists a trick to make this fast!
Parameters
X : torch.Tensor of shape (batch_size, input_dim)
Returns
f_var_diag : torch.Tensor of shape (batch_size, num_outputs) - corresponding to the diagonal of the covariance matrix of the outputs
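A usage sketch with the shapes from the signature above (continuing the earlier classifier example):

X_test = torch.randn(5, 20)                       # (batch_size, input_dim)
f_var_diag = la.functional_variance_fast(X_test)  # (batch_size, num_outputs) == (5, 3)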
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
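A sketch of the save/load round trip these methods support (hedged: follows the usual torch.save/torch.load pattern; the file name is a placeholder):

# Persist the fitted approximation ...
torch.save(la.state_dict(), 'laplace_state.pt')

# ... and restore it into a freshly constructed, structurally identical instance.
la_new = Laplace(model, 'classification',
                 subset_of_weights='last_layer',
                 hessian_structure='kron')
la_new.load_state_dict(torch.load('laplace_state.pt'))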
Inherited members
class FullLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See FullLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
Inherited members
class KronLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a Kronecker-factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for the last parameter group, i.e., torch.nn.Linear, that P \approx Q \otimes H. See KronLaplace, LLLaplace, and BaseLaplace for the full interface, and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and to compute posterior covariances, marginal likelihood, etc. Use of damping is possible by initializing or setting damping=True; see the construction sketch after this entry.
Ancestors
Inherited members
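A hedged construction sketch for the damping option mentioned above (model and train_loader as in the earlier example):

from laplace.lllaplace import KronLLLaplace

# Kronecker-factored last-layer Laplace with damped factors.
la_kron = KronLLLaplace(model, 'classification', damping=True)
la_kron.fit(train_loader)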
class DiagLLLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction | str | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, last_layer_name: str | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)
Last-layer Laplace approximation with a diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See DiagLaplace, LLLaplace, and BaseLaplace for the full interface.
Ancestors
Inherited members
class FunctionalLLLaplace (model: nn.Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, feature_reduction: FeatureReduction = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', last_layer_name: str = None, backend: type[CurvatureInterface] | None = laplace.curvature.backpack.BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)
In terms of GP inference, little changes here compared to the FunctionalLaplace class. Since we now treat only the last layer probabilistically and the rest of the network is used as a fixed feature extractor, the X \in \mathbb{R}^{M \times D} in GP inference changes to \tilde{X} \in \mathbb{R}^{M \times l_{n-1}}, where l_{n-1} is the dimension of the output of the penultimate NN layer.
See FunctionalLaplace for the full interface.
Ancestors
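A hedged construction sketch following the signature above (n_subset selects how many training points enter the subset-of-data GP approximation; 50 is a placeholder; model as in the earlier example):

from laplace.lllaplace import FunctionalLLLaplace

# Last-layer GP/functional Laplace on a subset of 50 training points.
la_fn = FunctionalLLLaplace(model, 'classification', n_subset=50)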
Methods
def fit(self, train_loader: DataLoader) ‑> None
Fit the Laplace approximation of a GP posterior.
Parameters
train_loader : torch.utils.data.DataLoader - train_loader.dataset needs to be set to access N, the size of the data set; train_loader.batch_size needs to be set to access b, the batch size
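A fit sketch emphasizing the loader requirements above (continuing the construction sketch; a standard DataLoader exposes both .dataset and .batch_size):

# The loader must expose .dataset (for N) and .batch_size (for b).
loader = DataLoader(TensorDataset(X, y), batch_size=10)
la_fn.fit(loader)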
def state_dict(self) ‑> dict
def load_state_dict(self, state_dict: dict)
Inherited members