Module laplace.utils.utils
Functions
def get_nll(out_dist: torch.Tensor, targets: torch.Tensor) -> torch.Tensor
def validate(laplace: BaseLaplace, val_loader: DataLoader, loss: torchmetrics.Metric | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], pred_type: PredType | str = PredType.GLM, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, dict_key_y: str = 'labels') -> float
def parameters_per_layer(model: nn.Module) -> list[int]
Get number of parameters per layer.
Parameters
model: torch.nn.Module
Returns
params_per_layer: list[int]
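As a rough illustration, a per-layer count can be sketched as below. This assumes that "layer" means each parameter tensor yielded by model.parameters(), so weights and biases are counted separately; the helper name count_params_per_tensor is hypothetical and not part of laplace.

```python
import torch.nn as nn


def count_params_per_tensor(model: nn.Module) -> list[int]:
    # Hypothetical helper: one count per parameter tensor, in the order
    # yielded by model.parameters() (weight, then bias, for each nn.Linear).
    return [p.numel() for p in model.parameters()]


model = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
print(count_params_per_tensor(model))  # [6, 2, 2, 1]
```

The nn.ReLU module has no parameters, so it contributes no entry to the list.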
def invsqrt_precision(M: torch.Tensor) -> torch.Tensor
Compute M^{-0.5} as a lower-triangular matrix.
Parameters
M: torch.Tensor
Returns
M_invsqrt: torch.Tensor
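One way to obtain such an inverse square-root factor of a precision matrix is sketched below. It assumes the goal is a lower-triangular L with L @ L.T equal to M^{-1} (a Cholesky factor of the covariance). The function name invsqrt_precision_sketch is hypothetical, and forming the explicit inverse is chosen for readability rather than numerical robustness.

```python
import torch


def invsqrt_precision_sketch(M: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch: for a symmetric positive-definite precision M,
    # return a lower-triangular L such that L @ L.T equals inverse(M).
    cov = torch.linalg.inv(M)
    # Symmetrize to remove round-off asymmetry before the Cholesky step.
    cov = 0.5 * (cov + cov.T)
    return torch.linalg.cholesky(cov)


torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
M = A @ A.T + 4 * torch.eye(4, dtype=torch.float64)  # SPD precision matrix
L = invsqrt_precision_sketch(M)
print(torch.allclose(L @ L.T, torch.linalg.inv(M)))  # True
```

A more robust variant would avoid the explicit inverse, for example by working with a Cholesky factor of M and triangular solves.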
def kron(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor
Computes the Kronecker product between two tensors.
Parameters
t1: torch.Tensor
t2: torch.Tensor
Returns
kron_product: torch.Tensor
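For 2-D inputs, the Kronecker product places t1[i, j] * t2 in block (i, j) of the result. The broadcasting sketch below (hypothetical name kron_2d, restricted to matrices) makes that index pattern explicit and can be checked against torch.kron; it is not necessarily how laplace implements kron.

```python
import torch


def kron_2d(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # Hypothetical sketch for matrices: (a, b) kron (c, d) -> (a*c, b*d).
    a, b = t1.shape
    c, d = t2.shape
    # Broadcast to shape (a, c, b, d) with entry [i, k, j, l] = t1[i, j] * t2[k, l],
    # then merge (i, k) into the row index and (j, l) into the column index.
    prod = t1[:, None, :, None] * t2[None, :, None, :]
    return prod.reshape(a * c, b * d)


t1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
t2 = torch.eye(2)
print(kron_2d(t1, t2))
print(torch.equal(kron_2d(t1, t2), torch.kron(t1, t2)))  # True
```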
def diagonal_add_scalar(X: torch.Tensor, value: torch.Tensor) -> torch.Tensor
Add the scalar value to the diagonal of X.
Parameters
X: torch.Tensor
value: torch.Tensor or float
Returns
X_add_scalar: torch.Tensor
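A minimal sketch of adding a scalar to a diagonal, assuming the input should be left unmodified (diagonal_add_scalar_sketch is a hypothetical name, not the laplace function). Writing through the diagonal view avoids allocating a dense identity matrix, as X + value * torch.eye(n) would.

```python
from typing import Union

import torch


def diagonal_add_scalar_sketch(X: torch.Tensor, value: Union[torch.Tensor, float]) -> torch.Tensor:
    # Hypothetical sketch: return a copy of X with `value` added to every
    # diagonal entry; X itself is not modified.
    out = X.clone()
    # diagonal() returns a view, so the in-place add writes into `out`.
    out.diagonal().add_(value)
    return out


print(diagonal_add_scalar_sketch(torch.zeros(2, 2), 3.0))
```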
def symeig(M: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]
Symmetric eigendecomposition that avoids failure cases by adding jitter to the diagonal and removing it again afterwards.
Parameters
M: torch.Tensor
Returns
L: torch.Tensor - eigenvalues
W: torch.Tensor - eigenvectors
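The jitter strategy works because adding c * I to a symmetric matrix shifts every eigenvalue by c without changing the eigenvectors, so the shift can be subtracted from the eigenvalues afterwards. A minimal sketch, assuming a single retry with a fixed jitter; the name symeig_with_jitter and the default jitter=1.0 are hypothetical, not taken from laplace.

```python
import torch


def symeig_with_jitter(M: torch.Tensor, jitter: float = 1.0) -> tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical sketch of the jitter idea; the jitter size is an assumption.
    try:
        L, W = torch.linalg.eigh(M)
    except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError subclass
        eye = torch.eye(M.shape[-1], dtype=M.dtype, device=M.device)
        # M + c*I has the same eigenvectors as M, with each eigenvalue shifted by c.
        L, W = torch.linalg.eigh(M + jitter * eye)
        L = L - jitter  # undo the shift on the eigenvalues
    return L, W


torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)
M = A + A.T  # symmetric test matrix
L, W = symeig_with_jitter(M)
print(torch.allclose(W @ torch.diag(L) @ W.T, M))  # True
```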
def block_diag(blocks: list[torch.Tensor]) -> torch.Tensor
Compose a block-diagonal matrix from individual blocks.
Parameters
blocks: list[torch.Tensor]
Returns
M: torch.Tensor
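The composition can be illustrated by copying each 2-D block into a zero matrix at a running row and column offset. block_diag_sketch is a hypothetical name and the laplace implementation may differ.

```python
import torch


def block_diag_sketch(blocks: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical sketch: place each 2-D block on the diagonal of a zero matrix.
    rows = sum(b.shape[0] for b in blocks)
    cols = sum(b.shape[1] for b in blocks)
    M = torch.zeros(rows, cols, dtype=blocks[0].dtype, device=blocks[0].device)
    r = c = 0
    for b in blocks:
        h, w = b.shape
        M[r:r + h, c:c + w] = b  # copy the block at the current offset
        r += h
        c += w
    return M


blocks = [torch.ones(2, 2), 2 * torch.ones(1, 3), torch.tensor([[5.0]])]
print(block_diag_sketch(blocks).shape)  # torch.Size([4, 6])
```

PyTorch also provides torch.block_diag(*blocks), which yields the same result for 2-D blocks.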
def expand_prior_precision(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor
Expand prior precision to match the shape of the model parameters.
Parameters
prior_prec: torch.Tensor (1-dimensional) - prior precision
model: torch.nn.Module - torch model with parameters that are regularized by prior_prec
Returns
expanded_prior_prec: torch.Tensor - expanded prior precision, with the same shape as the model parameters
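How a 1-dimensional prior_prec is matched to the parameters is not spelled out here. A plausible sketch assumes three cases: a single value shared by all parameters, one value per parameter tensor (layer-wise), or one value per scalar parameter. It also assumes the result is a flat vector ordered like torch.nn.utils.parameters_to_vector(model.parameters()). The function name is hypothetical.

```python
import torch
import torch.nn as nn


def expand_prior_precision_sketch(prior_prec: torch.Tensor, model: nn.Module) -> torch.Tensor:
    # Hypothetical sketch; the three accepted lengths below are assumptions.
    params = [p for p in model.parameters() if p.requires_grad]
    n_params = sum(p.numel() for p in params)
    assert prior_prec.ndim == 1, "prior_prec must be 1-dimensional"
    if len(prior_prec) == 1:  # one shared (scalar) precision
        return prior_prec.expand(n_params).clone()
    if len(prior_prec) == len(params):  # one precision per parameter tensor
        return torch.cat([
            delta * torch.ones(p.numel(), dtype=p.dtype, device=p.device)
            for delta, p in zip(prior_prec, params)
        ])
    if len(prior_prec) == n_params:  # already one precision per parameter
        return prior_prec.clone()
    raise ValueError("prior_prec length matches none of the assumed cases")


model = nn.Linear(3, 2)  # weight: 6 entries, bias: 2 entries
print(expand_prior_precision_sketch(torch.tensor([1.0, 2.0]), model))
```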
Classes
class SoDSampler(N, M, seed: int = 0)
Base class for all Samplers.
Every Sampler subclass must provide an __iter__ method, which iterates over indices or lists of indices (batches) of dataset elements, and a __len__ method, which returns the length of the returned iterators.
Args
data_source: Dataset - This argument is not used and will be removed in 2.2.0. You may still have a custom implementation that utilizes it.
Example
>>> # xdoctest: +SKIP
>>> from typing import Iterator, List
>>> import torch
>>> from torch.utils.data import Sampler
>>> class AccedingSequenceLengthSampler(Sampler[int]):
...     def __init__(self, data: List[str]) -> None:
...         self.data = data
...
...     def __len__(self) -> int:
...         return len(self.data)
...
...     def __iter__(self) -> Iterator[int]:
...         sizes = torch.tensor([len(x) for x in self.data])
...         yield from torch.argsort(sizes).tolist()
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
...     def __init__(self, data: List[str], batch_size: int) -> None:
...         self.data = data
...         self.batch_size = batch_size
...
...     def __len__(self) -> int:
...         return (len(self.data) + self.batch_size - 1) // self.batch_size
...
...     def __iter__(self) -> Iterator[List[int]]:
...         sizes = torch.tensor([len(x) for x in self.data])
...         for batch in torch.chunk(torch.argsort(sizes), len(self)):
...             yield batch.tolist()
Note: The __len__ method isn't strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
Ancestors
- torch.utils.data.sampler.Sampler
- typing.Generic
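The signature SoDSampler(N, M, seed=0) suggests a sampler over a subset of a dataset, but its behavior is not documented here. Assuming, hypothetically, that it yields M distinct indices drawn from range(N) with a seeded generator, a Sampler subclass honoring the __iter__/__len__ contract might look like this. SubsetSamplerSketch is an illustrative name, not the laplace class.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SubsetSamplerSketch(Sampler[int]):
    # Hypothetical sketch: yields M distinct indices out of range(N),
    # reproducibly for a fixed seed.
    def __init__(self, N: int, M: int, seed: int = 0) -> None:
        if not 0 <= M <= N:
            raise ValueError("need 0 <= M <= N")
        self.N, self.M, self.seed = N, M, seed

    def __iter__(self):
        g = torch.Generator().manual_seed(self.seed)
        # Random permutation of range(N), truncated to the first M indices.
        yield from torch.randperm(self.N, generator=g)[: self.M].tolist()

    def __len__(self) -> int:
        return self.M


data = TensorDataset(torch.arange(10.0))
loader = DataLoader(data, batch_size=4, sampler=SubsetSamplerSketch(N=10, M=6))
print([batch[0].tolist() for batch in loader])  # two batches: 4 and 2 elements
```

Re-seeding the generator inside __iter__ makes every pass over the loader visit the same subset, which suits a fixed subset-of-data setting.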