Module laplace.curvature
Sub-modules
laplace.curvature.asdfghjkl
laplace.curvature.asdl
laplace.curvature.backpack
laplace.curvature.curvature
laplace.curvature.curvlinops
Classes
class CurvatureInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Interface to access curvature for a model and corresponding likelihood. A subclass must inherit from this base class and implement the necessary functions jacobians, full, kron, and diag. The interface might be extended in the future to account for other curvature structures, for example, a block-diagonal one.
Parameters
model : torch.nn.Module or FeatureExtractor - torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False - only consider curvature of last layer
subnetwork_indices : torch.LongTensor, default=None - indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over
dict_key_x : str, default='input_ids' - The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels' - The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
Attributes
lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
factor : float - conversion factor between torch losses and base likelihoods. For example, \frac{1}{2} to get to \mathcal{N}(f, 1) from MSELoss.
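For intuition, a small worked sketch of this conversion for regression; the reduction='sum' configuration is an assumption made for illustration, not a statement about the internal loss setup:

    import torch
    from torch import nn

    # Up to an additive constant, the Gaussian negative log-likelihood with
    # unit variance is 0.5 * sum((y - f)^2), so factor = 0.5 maps torch's
    # summed MSE onto -log N(y | f, 1).
    f = torch.randn(4, 1)
    y = torch.randn(4, 1)
    lossfunc = nn.MSELoss(reduction='sum')
    factor = 0.5
    nll_up_to_const = factor * lossfunc(f, y)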
Subclasses
Methods
def jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta, via torch.func.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default=False - whether to enable backprop through the Js and f w.r.t. x
Returns
Js : torch.Tensor - Jacobians (batch, parameters, outputs)
f : torch.Tensor - output function (batch, outputs)
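Since the docstring references torch.func, here is a minimal, hedged sketch of how such Jacobians can be computed functionally; the model, inputs, and the axis ordering of the result follow torch.func's output and are illustrative rather than the backend's exact implementation:

    import torch
    from torch import nn
    from torch.func import functional_call, jacrev, vmap

    model = nn.Sequential(nn.Linear(3, 8), nn.Tanh(), nn.Linear(8, 2))
    x = torch.randn(5, 3)  # (batch, input_shape)
    params = dict(model.named_parameters())

    def f(p, xi):
        # Evaluate the model with explicit parameters on a single input
        return functional_call(model, p, (xi.unsqueeze(0),)).squeeze(0)

    # jacrev differentiates w.r.t. the parameter dict; vmap maps over the batch
    jac = vmap(jacrev(f), in_dims=(None, 0))(params, x)
    # Concatenate per-parameter Jacobians; here the layout comes out as
    # (batch, outputs, parameters)
    Js = torch.cat([j.flatten(2) for j in jac.values()], dim=2)
    f_out = model(x)  # (batch, outputs)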
def last_layer_jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute Jacobians \nabla_{\theta_\textrm{last}} f(x;\theta_\textrm{last}) only at current last-layer parameter \theta_{\textrm{last}}.
Parameters
x : torch.Tensor
enable_backprop : bool, default=False
Returns
Js : torch.Tensor - Jacobians (batch, outputs, last-layer-parameters)
f : torch.Tensor - output function (batch, outputs)
def gradients(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute batch gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
Returns
Gs : torch.Tensor - gradients (batch, parameters)
loss : torch.Tensor
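A hedged sketch of such batch (per-sample) gradients via torch.func, again with illustrative names and a plain classification setup:

    import torch
    from torch import nn
    from torch.func import functional_call, grad, vmap

    model = nn.Linear(3, 2)
    x, y = torch.randn(5, 3), torch.randint(0, 2, (5,))
    params = dict(model.named_parameters())

    def loss_fn(p, xi, yi):
        logits = functional_call(model, p, (xi.unsqueeze(0),))
        return nn.functional.cross_entropy(logits, yi.unsqueeze(0))

    # grad differentiates the scalar loss w.r.t. the parameter dict;
    # vmap turns it into per-sample gradients
    per_sample = vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
    Gs = torch.cat([g.flatten(1) for g in per_sample.values()], dim=1)
    # Gs has shape (batch, parameters)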
def full(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, **kwargs: dict[str, Any])-
Compute a dense curvature (approximation) in the form of a P \times P matrix H with respect to parameters \theta \in \mathbb{R}^P.
Parameters
x : torch.Tensor - input data (batch, input_shape)
y : torch.Tensor - labels (batch, label_shape)
Returns
loss : torch.Tensor
H : torch.Tensor - Hessian approximation (parameters, parameters)
def kron(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, N: int, **kwargs: dict[str, Any]) ‑> tuple[torch.Tensor, Kron]-
Compute a Kronecker factored curvature approximation (such as KFAC). The approximation to H takes the form of two Kronecker factors Q, H, i.e., H \approx Q \otimes H for each Module in the neural network permitting such curvature. Q is quadratic in the input-dimension of a module p_{in} \times p_{in} and H in the output-dimension p_{out} \times p_{out}.
Parameters
x : torch.Tensor - input data (batch, input_shape)
y : torch.Tensor - labels (batch, label_shape)
N : int - total number of data points
Returns
loss : torch.Tensor
H : Kron - Kronecker factored Hessian approximation.
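To see why the factored form is attractive, a toy sketch (sizes and names illustrative) comparing storage costs and applying a matrix-vector product through the factors without ever materializing the dense Kronecker product:

    import torch

    p_in, p_out = 512, 256
    Q = torch.randn(p_in, p_in)    # input-side factor
    H = torch.randn(p_out, p_out)  # output-side factor

    dense_entries = (p_in * p_out) ** 2     # dense block: ~1.7e10 entries
    factored_entries = p_in**2 + p_out**2   # factors:     ~3.3e5 entries

    # Matrix-vector products use the identity
    # (Q ⊗ H) vec(V) = vec(H V Q^T)  (column-major vec convention)
    V = torch.randn(p_out, p_in)
    mv = H @ V @ Q.T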
def diag(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, **kwargs: dict[str, Any])-
Compute a diagonal Hessian approximation to H, represented as a vector of the dimensionality of parameters \theta.
Parameters
x : torch.Tensor - input data (batch, input_shape)
y : torch.Tensor - labels (batch, label_shape)
Returns
loss : torch.Tensor
H : torch.Tensor - vector representing the diagonal of H
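For intuition, a hedged sketch of one such diagonal approximation, the diagonal empirical Fisher, built from per-sample gradients Gs of shape (batch, parameters) as returned by gradients; this is illustrative, not necessarily the backend's algorithm:

    import torch

    def diag_ef(Gs: torch.Tensor) -> torch.Tensor:
        # diag(sum_n g_n g_n^T) = sum_n g_n**2, computed without ever
        # forming the dense (parameters, parameters) matrix
        return (Gs ** 2).sum(dim=0)  # (parameters,)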
def functorch_jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta, via torch.func.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default=False - whether to enable backprop through the Js and f w.r.t. x
Returns
Js : torch.Tensor - Jacobians (batch, parameters, outputs)
f : torch.Tensor - output function (batch, outputs)
class GGNInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False, num_samples: int = 1)-
Generalized Gauss-Newton or Fisher Curvature Interface. The GGN is equal to the Fisher information for the available likelihoods. In addition to
CurvatureInterface, methods for Jacobians are required by subclasses.
Parameters
model : torch.nn.Module or FeatureExtractor - torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False - only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None - indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over
dict_key_x : str, default='input_ids' - The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels' - The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
stochastic : bool, default=False - Fisher if stochastic else GGN
num_samples : int, default=1 - Number of samples used to approximate the stochastic Fisher
Ancestors
Subclasses
Methods
def full(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, **kwargs: dict[str, Any]) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute the full GGN P \times P matrix as Hessian approximation H_{ggn} with respect to parameters \theta \in \mathbb{R}^P. For the last-layer case, this reduces to \theta_\textrm{last}.
Parameters
x : torch.Tensor - input data (batch, input_shape)
y : torch.Tensor - labels (batch, label_shape)
Returns
loss : torch.Tensor
H : torch.Tensor - GGN (parameters, parameters)
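A hedged sketch of how a full GGN can be assembled from Jacobians in the classification case, where the Hessian of the cross-entropy loss w.r.t. the logits is diag(p) - p p^T; Js and f are assumed to have shapes (batch, outputs, parameters) and (batch, outputs), which is illustrative rather than the backend's exact code path:

    import torch

    def full_ggn(Js: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        p = torch.softmax(f, dim=-1)  # predictive probabilities
        # Per-sample output Hessian of cross-entropy: diag(p) - p p^T
        Lam = torch.diag_embed(p) - torch.einsum('bi,bj->bij', p, p)
        # H_ggn = sum_n J_n^T Lam_n J_n, shape (parameters, parameters)
        return torch.einsum('bop,boq,bqr->pr', Js, Lam, Js)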
Inherited members
class EFInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Interface for Empirical Fisher as Hessian approximation. In addition to
CurvatureInterface, methods for gradients are required by subclasses.
Parameters
model : torch.nn.Module or FeatureExtractor - torch model (neural network)
likelihood : {'classification', 'regression'}
last_layer : bool, default=False - only consider curvature of last layer
subnetwork_indices : torch.Tensor, default=None - indices of the vectorized model parameters that define the subnetwork to apply the Laplace approximation over
dict_key_x : str, default='input_ids' - The dictionary key under which the input tensor x is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels' - The dictionary key under which the target tensor y is stored. Only has effect when the model takes a MutableMapping as the input. Useful for Huggingface LLM models.
Attributes
lossfunc : torch.nn.MSELoss or torch.nn.CrossEntropyLoss
factor : float - conversion factor between torch losses and base likelihoods. For example, \frac{1}{2} to get to \mathcal{N}(f, 1) from MSELoss.
Ancestors
Subclasses
Methods
def full(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor, **kwargs: dict[str, Any]) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute the full EF P \times P matrix as Hessian approximation H_{ef} with respect to parameters \theta \in \mathbb{R}^P. For the last-layer case, this reduces to \theta_\textrm{last}.
Parameters
x : torch.Tensor - input data (batch, input_shape)
y : torch.Tensor - labels (batch, label_shape)
Returns
loss : torch.Tensor
H_ef : torch.Tensor - EF (parameters, parameters)
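Analogously, a hedged sketch of the full empirical Fisher from per-sample gradients Gs of shape (batch, parameters); illustrative only:

    import torch

    def full_ef(Gs: torch.Tensor) -> torch.Tensor:
        # H_ef = sum_n g_n g_n^T, shape (parameters, parameters)
        return torch.einsum('np,nq->pq', Gs, Gs)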
Inherited members
class BackPackInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Interface for the BackPACK backend.
Ancestors
Subclasses
Methods
def jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute Jacobians \nabla_{\theta} f(x;\theta) at current parameter \theta using BackPACK's BatchGrad per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default=False - whether to enable backprop through the Js and f w.r.t. x
Returns
Js : torch.Tensor - Jacobians (batch, parameters, outputs)
f : torch.Tensor - output function (batch, outputs)
def gradients(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta using BackPACK's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
Returns
Gs : torch.Tensor - gradients (batch, parameters)
loss : torch.Tensor
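A minimal sketch of the BatchGrad mechanism underlying this method; the model and data are illustrative, and both the module and the loss must be extended for BackPACK to attach per-sample quantities:

    import torch
    from torch import nn
    from backpack import backpack, extend
    from backpack.extensions import BatchGrad

    model = extend(nn.Linear(3, 2))
    lossfunc = extend(nn.MSELoss(reduction='sum'))
    x, y = torch.randn(5, 3), torch.randn(5, 2)

    loss = lossfunc(model(x), y)
    with backpack(BatchGrad()):
        loss.backward()

    # Each parameter now carries .grad_batch of shape (batch, *param_shape)
    Gs = torch.cat([p.grad_batch.flatten(1) for p in model.parameters()], dim=1)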
Inherited members
class BackPackGGN (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)-
Implementation of the
GGNInterface using BackPACK.
Ancestors
Inherited members
class BackPackEF (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Implementation of
EFInterface using BackPACK.
Ancestors
Inherited members
class AsdlInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Interface for asdfghjkl backend.
Ancestors
Subclasses
Instance variables
var loss_type : str
Methods
def jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute Jacobians \nabla_\theta f(x;\theta) at current parameter \theta using asdfghjkl's gradient per output dimension.
Parameters
x : torch.Tensor or MutableMapping (e.g. dict, UserDict) - input data (batch, input_shape) on compatible device with model if torch.Tensor. If MutableMapping, it must at least contain self.dict_key_x. The latter is specific to reward modeling.
enable_backprop : bool, default=False - whether to enable backprop through the Js and f w.r.t. x
Returns
Js : torch.Tensor - Jacobians (batch, parameters, outputs)
f : torch.Tensor - output function (batch, outputs)
def gradients(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor) ‑> tuple[torch.Tensor, torch.Tensor]-
Compute gradients \nabla_\theta \ell(f(x;\theta), y) at current parameter \theta using asdfghjkl's backend.
Parameters
x : torch.Tensor - input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
Returns
loss : torch.Tensor
Gs : torch.Tensor - gradients (batch, parameters)
Inherited members
class AsdlGGN (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)-
Implementation of the
GGNInterface using asdfghjkl.
Ancestors
Inherited members
class AsdlEF (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Implementation of the
EFInterface using asdfghjkl.
Ancestors
Inherited members
class AsdlHessian (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Implementation of the full Hessian using asdfghjkl.
Ancestors
Inherited members
class CurvlinopsInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Interface for Curvlinops backend. https://github.com/f-dangel/curvlinops
Ancestors
Subclasses
Inherited members
class CurvlinopsGGN (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)-
Implementation of the
GGNInterface using Curvlinops.
Ancestors
Inherited members
class CurvlinopsEF (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Implementation of
EFInterface using Curvlinops.
Ancestors
Inherited members
class CurvlinopsHessian (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')-
Implementation of the full Hessian using Curvlinops.
Ancestors
Inherited members
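These backend classes are usually not instantiated directly but passed to the top-level Laplace interface. A hedged usage sketch, assuming the laplace package's backend keyword; the model and hyperparameters are illustrative:

    import torch.nn as nn
    from laplace import Laplace
    from laplace.curvature import CurvlinopsGGN

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    # `backend` selects which curvature interface computes the Hessian
    # approximation; any of the classes above should work analogously.
    la = Laplace(
        model,
        likelihood='classification',
        subset_of_weights='all',
        hessian_structure='kron',
        backend=CurvlinopsGGN,
    )
    # la.fit(train_loader) then uses the backend's kron method under the hood.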