Module laplace.curvature.backpack
Classes
class BackPackInterface (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')
Interface for Backpack backend.
Ancestors
CurvatureInterface
Subclasses
BackPackEF
BackPackGGN
Methods
def jacobians(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], enable_backprop: bool = False) -> tuple[torch.Tensor, torch.Tensor]
Compute Jacobians $\nabla_{\theta} f(x;\theta)$ at the current parameter $\theta$ using BackPACK's BatchGrad, one backward pass per output dimension. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.
Parameters
x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
enable_backprop : bool, default=False
    whether to enable backprop through the Js and f w.r.t. x
Returns
Js : torch.Tensor
    Jacobians (batch, parameters, outputs)
f : torch.Tensor
    output function (batch, outputs)
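The per-output-dimension scheme can be mirrored with plain autograd when BackPACK is not available, for example to sanity-check results on a tiny model. The sketch below is a slow reference, not the BackPACK code path: it loops over samples and output dimensions, calls torch.autograd.grad on each scalar output, and stacks the results in the (batch, parameters, outputs) layout stated under Returns. The function name reference_jacobians is hypothetical and not part of this module.

```python
import torch
from torch import nn


def reference_jacobians(
    model: nn.Module, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical plain-autograd reference (no BackPACK).

    Returns Js with shape (batch, parameters, outputs) and f with shape
    (batch, outputs); parameters are flattened in model.parameters() order.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    f = model(x)  # assumes a (batch, outputs) model output
    batch, outputs = f.shape
    per_sample = []
    for i in range(batch):
        columns = []
        for k in range(outputs):
            # Gradient of the single scalar f[i, k] w.r.t. every parameter.
            grads = torch.autograd.grad(
                f[i, k], params, retain_graph=True, allow_unused=True
            )
            # Parameters that do not influence f[i, k] get an explicit zero block.
            columns.append(
                torch.cat(
                    [
                        (torch.zeros_like(p) if g is None else g).reshape(-1)
                        for p, g in zip(params, grads)
                    ]
                )
            )
        per_sample.append(torch.stack(columns, dim=1))  # (parameters, outputs)
    Js = torch.stack(per_sample, dim=0)  # (batch, parameters, outputs)
    return Js, f.detach()
```

This costs batch × outputs backward passes; BatchGrad avoids the inner loop over samples by returning all per-sample gradients from a single backward pass per output dimension.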
def gradients(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Compute per-sample gradients $\nabla_\theta \ell(f(x;\theta), y)$ at the current parameter $\theta$ using BackPACK's BatchGrad. Note that BackPACK doesn't play well with torch.func, so this method has to be overridden.
Parameters
x : torch.Tensor
    input data (batch, input_shape) on compatible device with model.
y : torch.Tensor
Returns
Gs : torch.Tensor
    gradients (batch, parameters)
loss : torch.Tensor
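Per-sample loss gradients can be cross-checked against a plain-autograd loop in the same way. This sketch assumes a summed loss (reduction="sum"), so the rows of Gs add up to the gradient of the returned batch loss; the name reference_gradients and its loss_fn argument are hypothetical, not part of this module.

```python
import torch
from torch import nn


def reference_gradients(
    model: nn.Module, x: torch.Tensor, y: torch.Tensor, loss_fn: nn.Module
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical reference: per-sample gradients Gs (batch, parameters) and the summed loss."""
    params = [p for p in model.parameters() if p.requires_grad]
    f = model(x)
    rows = []
    for i in range(x.shape[0]):
        # Slicing with i:i+1 keeps the batch dimension that loss modules expect.
        loss_i = loss_fn(f[i : i + 1], y[i : i + 1])
        grads = torch.autograd.grad(
            loss_i, params, retain_graph=True, allow_unused=True
        )
        rows.append(
            torch.cat(
                [
                    (torch.zeros_like(p) if g is None else g).reshape(-1)
                    for p, g in zip(params, grads)
                ]
            )
        )
    loss = loss_fn(f, y)  # assumes reduction="sum", i.e. the sum of per-sample losses
    return torch.stack(rows), loss.detach()
```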
class BackPackGGN (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', stochastic: bool = False)
Implementation of the GGNInterface using Backpack.
Ancestors
BackPackInterface
GGNInterface
CurvatureInterface
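The generalized Gauss-Newton matrix is $G = \sum_i J_i H_i J_i^\top$, where $J_i$ is the (parameters, outputs) Jacobian of sample $i$ and $H_i$ is the Hessian of the loss w.r.t. the model output. Given Jacobians in the (batch, parameters, outputs) layout documented for jacobians, the dense contraction is one einsum. This is a minimal sketch: ggn_from_jacobians is a hypothetical helper, and BackPACK itself provides dedicated extensions (such as DiagGGNExact or KFAC) for diagonal and Kronecker-factored curvature rather than forming this dense matrix.

```python
import torch


def ggn_from_jacobians(
    Js: torch.Tensor, H: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: dense GGN and its diagonal.

    Js: (batch, parameters, outputs) Jacobians of the model output.
    H:  (batch, outputs, outputs) Hessians of the loss w.r.t. the output.
    """
    G = torch.einsum("bpk,bkl,bql->pq", Js, H, Js)  # sum_i J_i H_i J_i^T
    diag = torch.einsum("bpk,bkl,bpl->p", Js, H, Js)  # only the diagonal of G
    return G, diag
```

For a model that is linear in its parameters under a squared-error loss, G equals the exact Hessian; for nonlinear models it drops the network's second-derivative terms, which keeps G positive semi-definite whenever every $H_i$ is.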
class BackPackEF (model: nn.Module, likelihood: Likelihood | str, last_layer: bool = False, subnetwork_indices: torch.LongTensor | None = None, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels')
Implementation of the EFInterface using Backpack.
Ancestors
BackPackInterface
EFInterface
CurvatureInterface
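The empirical Fisher replaces the Jacobian-Hessian product of the GGN with outer products of per-sample loss gradients, $F = \sum_i g_i g_i^\top$. With gradients in the (batch, parameters) layout that gradients returns, this reduces to a single matrix product. A minimal sketch, where ef_from_gradients is a hypothetical helper for illustration:

```python
import torch


def ef_from_gradients(Gs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: dense empirical Fisher and its diagonal from Gs (batch, parameters)."""
    F = Gs.T @ Gs  # sum_i g_i g_i^T, shape (parameters, parameters)
    diag = (Gs**2).sum(dim=0)  # diagonal of F without forming the full matrix
    return F, diag
```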