Module laplace.utils.matrix
Classes
class Kron(kfacs: list[tuple[torch.Tensor] | torch.Tensor])

Kronecker-factored approximate curvature representation for a corresponding neural network. Each element in kfacs is either a tuple or a single matrix. A tuple represents the two Kronecker factors Q and H; a single element is a full block Hessian approximation.

Parameters
kfacs : list[tuple[torch.Tensor] | torch.Tensor]
    each element in the list is a tuple of two Kronecker factors Q, H or a single matrix approximating the Hessian (in the case of a bias, for example)
Static methods
def init_from_model(model: nn.Module | Iterable[nn.Parameter], device: torch.device) -> Kron

Initialize Kronecker factors based on a model's architecture.

Parameters
model : nn.Module or iterable of nn.Parameter, e.g. model.parameters()
device : torch.device

Returns
kron : Kron
Methods
def decompose(self, damping: bool = False) -> KronDecomposed

Eigendecompose the Kronecker factors and turn them into a KronDecomposed.

Parameters
damping : bool
    use damping

Returns
kron_decomposed : KronDecomposed
def bmm(self, W: torch.Tensor, exponent: float = 1) -> torch.Tensor

Batched matrix multiplication with the Kronecker factors. If Kron is H, we compute H @ W. This is useful for computing the predictive or a regularization based on Kronecker factors, as in continual learning.

Parameters
W : torch.Tensor
    matrix (batch, classes, params)
exponent : float, default=1
    can only be 1 for Kron; KronDecomposed is required for other exponent values of the Kronecker factors

Returns
SW : torch.Tensor
    result (batch, classes, params)
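The reason bmm never needs the dense Hessian is the standard Kronecker identity: with row-major flattening, (A ⊗ B) vec(X) = vec(A X Bᵀ). A minimal pure-Python sketch of that identity (illustrative names A, B, X; this is the underlying math, not the library's implementation):

```python
# Sketch of the identity behind Kron.bmm: for Kronecker factors A (m x m)
# and B (n x n), (A ⊗ B) vec(X) = vec(A X B^T) with row-major vec, so
# H @ W never requires materializing the dense mn x mn block.

def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(P):
    return [list(col) for col in zip(*P)]

def kron(A, B):
    # dense Kronecker product, only for checking the identity
    m, n = len(A), len(B)
    return [[A[i // n][j // n] * B[i % n][j % n]
             for j in range(m * n)] for i in range(m * n)]

A = [[2.0, 1.0], [1.0, 3.0]]
B = [[1.0, 0.5], [0.5, 2.0]]
X = [[1.0, 2.0], [3.0, 4.0]]

# Dense route: materialize A ⊗ B and multiply by vec(X).
vec_x = [v for row in X for v in row]
dense = [sum(h * x for h, x in zip(row, vec_x)) for row in kron(A, B)]

# Factored route: A X B^T, then flatten -- much cheaper for large factors.
factored = [v for row in matmul(matmul(A, X), transpose(B)) for v in row]

assert all(abs(d - f) < 1e-10 for d, f in zip(dense, factored))
```

The factored route costs O(m²n + mn²) per block instead of O(m²n²), which is what makes batched products over all layers tractable.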
def logdet(self) -> torch.Tensor

Compute the log determinant of each group of Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation.

Returns
logdet : torch.Tensor
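The per-block computation rests on a determinant identity for Kronecker products: for A of size m x m and B of size n x n, logdet(A ⊗ B) = n·logdet(A) + m·logdet(B). A small pure-Python check of that identity (illustrative 2x2 values, not the library code):

```python
# Sketch of the identity Kron.logdet relies on:
# logdet(A ⊗ B) = n * logdet(A) + m * logdet(B),
# so the dense Kronecker product is never formed.
import math

def logdet2x2(M):
    # log determinant of a 2x2 matrix
    return math.log(M[0][0] * M[1][1] - M[0][1] * M[1][0])

A = [[2.0, 1.0], [1.0, 3.0]]   # det = 5
B = [[4.0, 0.0], [0.0, 2.0]]   # det = 8
m = n = 2

via_factors = n * logdet2x2(A) + m * logdet2x2(B)
# det(A ⊗ B) = det(A)^n * det(B)^m = 25 * 64 = 1600
direct = math.log(1600.0)
assert abs(via_factors - direct) < 1e-12
```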
def diag(self) -> torch.Tensor

Extract the diagonal of the entire Kronecker factorization.

Returns
diag : torch.Tensor

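Per block, the diagonal also follows directly from the factors: diag(A ⊗ B) is the outer product of diag(A) and diag(B), flattened row-major. A short pure-Python illustration (example values, not the library code):

```python
# Sketch of the per-block computation behind Kron.diag:
# diag(A ⊗ B) = outer(diag(A), diag(B)), flattened row-major,
# so no dense matrix is required.
A = [[2.0, 1.0], [1.0, 3.0]]
B = [[1.0, 0.5], [0.5, 4.0]]

diag_A = [A[0][0], A[1][1]]    # [2.0, 3.0]
diag_B = [B[0][0], B[1][1]]    # [1.0, 4.0]
diag_kron = [a * b for a in diag_A for b in diag_B]
print(diag_kron)  # [2.0, 8.0, 3.0, 12.0]
```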
def to_matrix(self) -> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Returns
block_diag : torch.Tensor
class KronDecomposed(eigenvectors: list[tuple[torch.Tensor]], eigenvalues: list[tuple[torch.Tensor]], deltas: torch.Tensor | None = None, damping: bool = False)

Decomposed Kronecker-factored approximate curvature representation for a corresponding neural network. Each matrix in Kron is decomposed to obtain KronDecomposed. Front-loading the decomposition allows cheap repeated computation of inverses and log determinants. In contrast to Kron, we can add a scalar or layerwise scalars, but we cannot add another Kron or KronDecomposed anymore.

Parameters
eigenvectors : list[tuple[torch.Tensor]]
    eigenvectors corresponding to matrices in a corresponding Kron
eigenvalues : list[tuple[torch.Tensor]]
    eigenvalues corresponding to matrices in a corresponding Kron
deltas : torch.Tensor
    addend for each group of Kronecker factors representing, for example, a prior precision
damping : bool, default=False
    use the damping approximation, mixing prior and Kron partially multiplicatively
Methods
def detach(self) -> KronDecomposed

def logdet(self) -> torch.Tensor

Compute the log determinant of each group of Kronecker factors and sum them up. This corresponds to the log determinant of the entire Hessian approximation. In contrast to Kron.logdet(), additive deltas corresponding to prior precisions are included.

Returns
logdet : torch.Tensor
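The deltas can be folded in cheaply because the eigenvalues of Q ⊗ H are the pairwise products of the factors' eigenvalues: logdet(Q ⊗ H + δI) = Σᵢⱼ log(λᵢ μⱼ + δ). A pure-Python sketch with illustrative eigenvalues (this shows the math for the undamped case, not the library's implementation):

```python
# Sketch of how an additive delta (e.g. a prior precision) enters the
# log determinant: the eigenvalues of Q ⊗ H are the products lam_i * mu_j,
# so logdet(Q ⊗ H + delta * I) = sum over log(lam_i * mu_j + delta).
import math

eigvals_Q = [5.0, 1.0]   # eigenvalues of the first factor (illustrative)
eigvals_H = [3.0, 2.0]   # eigenvalues of the second factor (illustrative)
delta = 0.1              # additive prior precision (illustrative)

logdet = sum(math.log(l * m + delta)
             for l in eigvals_Q for m in eigvals_H)
```

This is why the eigendecomposition is front-loaded: changing delta (e.g. when tuning the prior precision) only changes this cheap sum, not the decomposition.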
def inv_square_form(self, W: torch.Tensor) -> torch.Tensor

def bmm(self, W: torch.Tensor, exponent: float = -1) -> torch.Tensor

Batched matrix multiplication with the decomposed Kronecker factors. This is useful for computing the predictive or a regularization loss. Compared to Kron.bmm(), a prior can be added here in the form of deltas, and the exponent can be other than just 1. Computes H^{exponent} W.

Parameters
W : torch.Tensor
    matrix (batch, classes, params)
exponent : float, default=-1

Returns
SW : torch.Tensor
    result (batch, classes, params)
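Arbitrary exponents become cheap once the eigendecomposition is available: with H = V diag(λ) Vᵀ, we have H^p w = V diag(λ^p) Vᵀ w, so exponent -1 yields an inverse-vector product without ever forming H⁻¹. A hand-built 2x2 sketch in pure Python (illustrative values; eigenvalues are assumed to already include any deltas):

```python
# Sketch of why KronDecomposed.bmm supports arbitrary exponents:
# with H = V diag(lam) V^T, H^p w = V diag(lam^p) V^T w.
import math

c, s = math.cos(0.3), math.sin(0.3)
V = [[c, -s], [s, c]]          # orthonormal eigenvectors (hand-built)
lam = [4.0, 1.0]               # eigenvalues, deltas assumed folded in
w = [1.0, 2.0]

def apply_power(p, w):
    # V^T w, scale coefficients by lam^p, map back to the original basis
    coeff = [sum(V[i][k] * w[i] for i in range(2)) for k in range(2)]
    coeff = [lam[k] ** p * coeff[k] for k in range(2)]
    return [sum(V[i][k] * coeff[k] for k in range(2)) for i in range(2)]

hw = apply_power(1.0, w)       # H @ w
back = apply_power(-1.0, hw)   # H^{-1} @ (H @ w) recovers w
assert all(abs(b - x) < 1e-12 for b, x in zip(back, w))
```

The same scaling of eigenvalue coefficients generalizes to fractional exponents such as -1/2, which is what makes posterior sampling and predictive computations with the decomposed factors cheap.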
def diag(self, exponent: float = 1) -> torch.Tensor

Extract the diagonal of the entire decomposed Kronecker factorization.

Parameters
exponent : float, default=1
    exponent of the Kronecker factorization

Returns
diag : torch.Tensor
def to_matrix(self, exponent: float = 1) -> torch.Tensor

Make the Kronecker factorization dense by computing the Kronecker product. Warning: this should only be used for testing purposes, as it will allocate large amounts of memory for big architectures.

Parameters
exponent : float, default=1
    exponent of the Kronecker factorization

Returns
block_diag : torch.Tensor