Module laplace.baselaplace
Classes
class BaseLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)-
Base class for all Laplace approximations in this library.
Parameters
model : torch.nn.Module
likelihood : Likelihood or str in {'classification', 'regression', 'reward_modeling'}
    Determines the log likelihood Hessian approximation. In the case of 'reward_modeling', it fits Laplace using the classification likelihood, then does prediction as in the regression likelihood. The model needs to be defined accordingly: the forward pass during training takes x.shape == (batch_size, 2, dim) with y.shape == (batch_size,), while during evaluation x.shape == (batch_size, dim). Note that 'reward_modeling' only supports KronLaplace and DiagLaplace.
sigma_noise : torch.Tensor or float, default=1
    Observation noise for the regression setting; must be 1 for classification.
prior_precision : torch.Tensor or float, default=1
    Prior precision of a Gaussian prior (= weight decay); can be scalar, per-layer, or diagonal in the most general case.
prior_mean : torch.Tensor or float, default=0
    Prior mean of a Gaussian prior; useful for continual learning.
temperature : float, default=1
    Temperature of the likelihood; lower temperature leads to a more concentrated posterior and vice versa.
enable_backprop : bool, default=False
    Whether to enable backprop to the input x through the Laplace predictive. Useful, e.g., for Bayesian optimization.
dict_key_x : str, default='input_ids'
    The dictionary key under which the input tensor x is stored. Only has an effect when the model takes a MutableMapping as input. Useful for Huggingface LLM models.
dict_key_y : str, default='labels'
    The dictionary key under which the target tensor y is stored. Only has an effect when the model takes a MutableMapping as input. Useful for Huggingface LLM models.
backend : subclasses of CurvatureInterface
    Backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
backend_kwargs : dict, default=None
    Arguments passed to the backend on initialization, for example to set the number of MC samples for stochastic approximations.
asdl_fisher_kwargs : dict, default=None
    Arguments passed to the ASDL backend specifically on initialization.
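A minimal construction sketch, assuming a trained torch.nn.Module and using the concrete DiagLaplace subclass documented below (BaseLaplace itself is abstract); the two-layer model and the random regression data are placeholders:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from laplace.baselaplace import DiagLaplace

# Placeholder regression model and data; any trained nn.Module works here.
model = nn.Sequential(nn.Linear(3, 32), nn.Tanh(), nn.Linear(32, 1))
X, y = torch.randn(128, 3), torch.randn(128, 1)
train_loader = DataLoader(TensorDataset(X, y), batch_size=32)

# Construct the Laplace approximation with the arguments documented above.
la = DiagLaplace(
    model,
    likelihood="regression",
    sigma_noise=1.0,
    prior_precision=1.0,
    prior_mean=0.0,
    temperature=1.0,
)
la.fit(train_loader)
```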
Subclasses
Instance variables
var backend : CurvatureInterface
var log_likelihood : torch.Tensor
    Compute the log likelihood on the training data after .fit() has been called. The log likelihood is computed on demand based on the loss and, for example, the observation noise, which makes it differentiable in the latter for iterative updates.
    Returns
    log_likelihood : torch.Tensor
var prior_precision_diag : torch.Tensor-
Obtain the diagonal prior precision p_0 constructed from either a scalar, layer-wise, or diagonal prior precision.
Returns
prior_precision_diag:torch.Tensor
var prior_mean : torch.Tensor
var prior_precision : torch.Tensor
var sigma_noise : torch.Tensor
Methods
def fit(self, train_loader: DataLoader) ‑> None
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None) ‑> torch.Tensor
def predictive(self, x: torch.Tensor, pred_type: PredType | str, link_approx: LinkApprox | str, n_samples: int) ‑> torch.Tensor | tuple[torch.Tensor, torch.Tensor]
def optimize_prior_precision(self, pred_type: PredType | str, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.DIAG, val_loader: DataLoader | None = None, loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) ‑> None
Optimize the prior precision post-hoc using the method specified by the user.
Parameters
pred_type : PredType or str in {'glm', 'nn'}
    Type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
method : TuningMethod or str in {'marglik', 'gridsearch'}, default=TuningMethod.MARGLIK
    Specifies how the prior precision should be optimized.
n_steps : int, default=100
    The number of gradient descent steps to take.
lr : float, default=1e-1
    The learning rate to use for gradient descent.
init_prior_prec : float or tensor, default=1.0
    Initial prior precision before the first optimization step.
prior_structure : PriorStructure or str in {'scalar', 'layerwise', 'diag'}, default=PriorStructure.SCALAR
    If init_prior_prec is scalar, the prior precision is optimized with this structure; otherwise, the structure of init_prior_prec is maintained.
val_loader : torch.data.utils.DataLoader, default=None
    DataLoader for the validation set; each iterate is a training batch (X, y).
loss : callable or torchmetrics.Metric, default=None
    Loss function to use for CV. If callable, the loss is computed offline (memory intensive). If a torchmetrics.Metric, a running loss is computed (efficient). The default depends on the likelihood: RunningNLLMetric() for classification and reward modeling, running MeanSquaredError() for regression.
log_prior_prec_min : float, default=-4
    Lower bound of the gridsearch interval.
log_prior_prec_max : float, default=4
    Upper bound of the gridsearch interval.
grid_size : int, default=100
    Number of values to consider inside the gridsearch interval.
link_approx : LinkApprox or str in {'mc', 'probit', 'bridge'}, default=LinkApprox.PROBIT
    How to approximate the classification link function for the 'glm' predictive. For pred_type='nn', only 'mc' is possible.
n_samples : int, default=100
    Number of samples for link_approx='mc'.
verbose : bool, default=False
    If True, the optimized prior precision is printed (can be a large tensor if the prior has a diagonal covariance).
progress_bar : bool, default=False
    Whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when subset_of_weights='all'.
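Two hedged sketches of how these arguments might be combined, assuming `la` is an already fitted Laplace object such as the one constructed earlier; the validation loader is a placeholder:

```python
# Marginal-likelihood based optimization (no validation data needed).
la.optimize_prior_precision(
    pred_type="glm",
    method="marglik",
    n_steps=100,
    lr=1e-1,
)

# Grid search over the prior precision on a held-out set.
la.optimize_prior_precision(
    pred_type="glm",
    method="gridsearch",
    val_loader=val_loader,       # placeholder DataLoader of (X, y) batches
    log_prior_prec_min=-4,
    log_prior_prec_max=4,
    grid_size=100,
)
```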
class ParametricLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)-
Parametric Laplace class.
Subclasses need to specify how the Hessian approximation is initialized, how to add up curvature over training data, how to sample from the Laplace approximation, and how to compute the functional variance.
A Laplace approximation is represented by a MAP estimate, given by the model parameters, and a posterior precision or covariance specifying a Gaussian distribution \mathcal{N}(\theta_{MAP}, P^{-1}). The goal of this class is to compute the posterior precision P, which sums as P = \sum_{n=1}^N \nabla^2_\theta \log p(\mathcal{D}_n \mid \theta) \vert_{\theta_{MAP}} + \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}}. Every subclass implements a different approximation to the log likelihood Hessians, for example, a diagonal one. The prior is assumed to be Gaussian and therefore we have a simple form for \nabla^2_\theta \log p(\theta) \vert_{\theta_{MAP}} = P_0. In particular, we assume a scalar, layer-wise, or diagonal prior precision so that in all cases P_0 = \textrm{diag}(p_0), and the structure of p_0 can be varied.
Ancestors
Subclasses
Instance variables
var scatter : torch.Tensor
    Computes the scatter, a term of the log marginal likelihood that corresponds to L2 regularization:
    scatter = (\theta_{MAP} - \mu_0)^T P_0 (\theta_{MAP} - \mu_0).
    Returns
    scatter : torch.Tensor
var log_det_prior_precision : torch.Tensor-
Compute log determinant of the prior precision \log \det P_0
Returns
log_det:torch.Tensor
var log_det_posterior_precision : torch.Tensor-
Compute the log determinant of the posterior precision \log \det P, which depends on the Hessian-approximation structure used by the subclass.
Returns
log_det:torch.Tensor
var log_det_ratio : torch.Tensor-
Compute the log determinant ratio, a part of the log marginal likelihood. \log \frac{\det P}{\det P_0} = \log \det P - \log \det P_0
Returns
log_det_ratio:torch.Tensor
var posterior_precision : torch.Tensor-
Compute or return the posterior precision P.
Returns
posterior_prec:torch.Tensor
Methods
def fit(self, train_loader: DataLoader, override: bool = True, progress_bar: bool = False) ‑> None-
Fit the local Laplace approximation at the parameters of the model.
Parameters
train_loader : torch.data.utils.DataLoader
    Each iterate is a training batch, either (X, y) tensors or a dict-like object containing keys as expressed by self.dict_key_x and self.dict_key_y. train_loader.dataset needs to be set to access N, the size of the data set.
override : bool, default=True
    Whether to initialize H, loss, and n_data again; setting to False is useful for online learning settings to accumulate a sequential posterior approximation.
progress_bar : bool, default=False
    Whether to show a progress bar; updated at every batch-Hessian computation. Useful for very large models and large amounts of data, especially when subset_of_weights='all'.
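A sketch of the online/continual use of override, assuming `la` was constructed as in the earlier example and `new_loader` is a placeholder DataLoader with additional (X, y) batches:

```python
# First fit on the initial data.
la.fit(train_loader)

# Later, accumulate curvature from new data on top of the existing
# approximation instead of recomputing it from scratch.
la.fit(new_loader, override=False, progress_bar=True)
```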
def square_norm(self, value) ‑> torch.Tensor
    Compute the square norm under the posterior precision with value - self.mean as \Delta: \Delta^\top P \Delta
    Returns
    square_form
def log_prob(self, value: torch.Tensor, normalized: bool = True) ‑> torch.Tensor-
Compute the log probability under the (current) Laplace approximation.
Parameters
value : torch.Tensor
normalized : bool, default=True
    Whether to return the log of a properly normalized Gaussian or just the terms that depend on value.
Returns
log_prob:torch.Tensor
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None) ‑> torch.Tensor-
Compute the Laplace approximation to the log marginal likelihood subject to specific Hessian approximations that subclasses implement. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.
Parameters
prior_precision : torch.Tensor, optional
    Prior precision if it should be changed from the current prior_precision value.
sigma_noise : torch.Tensor, optional
    Observation noise standard deviation if it should be changed.
Returns
log_marglik:torch.Tensor
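Because the returned tensor is differentiable in prior_precision and sigma_noise, both hyperparameters can be tuned by gradient descent; a minimal sketch, assuming a fitted regression Laplace object `la`:

```python
import torch

log_prior = torch.zeros(1, requires_grad=True)   # log prior precision
log_sigma = torch.zeros(1, requires_grad=True)   # log observation noise
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)

for _ in range(100):
    hyper_optimizer.zero_grad()
    # Overwrite the current hyperparameters and maximize the marginal likelihood.
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
```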
def predictive_samples(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None) ‑> torch.Tensor-
Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.
Parameters
x : torch.Tensor or MutableMapping
    Input data (batch_size, input_shape).
pred_type : {'glm', 'nn'}, default='glm'
    Type of posterior predictive, linearized GLM predictive or neural network sampling predictive. The GLM predictive is consistent with the curvature approximations used here.
n_samples : int
    Number of samples.
diagonal_output : bool
    Whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.
generator : torch.Generator, optional
    Random number generator to control the samples (if sampling is used).
Returns
samples:torch.Tensor- samples
(n_samples, batch_size, output_shape)
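A sketch of Thompson-sampling style use, assuming a fitted `la` and a placeholder batch of candidate inputs `x_candidates`:

```python
import torch

# Draw one function sample over the candidate batch and act greedily on it.
samples = la.predictive_samples(
    x_candidates,          # placeholder tensor of shape (batch_size, input_shape)
    pred_type="glm",
    n_samples=1,
    generator=torch.Generator().manual_seed(0),
)
# samples has shape (n_samples, batch_size, output_shape).
best_idx = samples[0].squeeze(-1).argmax()
```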
def functional_variance(self, Js: torch.Tensor) ‑> torch.Tensor-
Compute the functional variance for the 'glm' predictive: f_var[i] = Js[i] @ P.inv() @ Js[i].T, which is an output x output predictive covariance matrix. Mathematically, we have for a single Jacobian \mathcal{J} = \nabla_\theta f(x;\theta)\vert_{\theta_{MAP}} the output covariance matrix \mathcal{J} P^{-1} \mathcal{J}^T.
Parameters
Js : torch.Tensor
    Jacobians of model output w.r.t. parameters (batch, outputs, parameters)
Returns
f_var:torch.Tensor- output covariance
(batch, outputs, outputs)
def functional_covariance(self, Js: torch.Tensor) ‑> torch.Tensor-
Compute the functional covariance for the 'glm' predictive: f_cov = Js @ P.inv() @ Js.T, which is a batch*output x batch*output predictive covariance matrix. This emulates the GP posterior covariance N([f(x1), …, f(xm)], Cov[f(x1), …, f(xm)]). Useful for joint predictions, such as in batched Bayesian optimization.
Parameters
Js : torch.Tensor
    Jacobians of model output w.r.t. parameters (batch*outputs, parameters)
Returns
f_cov:torch.Tensor- output covariance
(batch*outputs, batch*outputs)
def sample(self, n_samples: int = 100, generator: torch.Generator | None = None) ‑> torch.Tensor-
Sample from the Laplace posterior approximation, i.e., \theta \sim \mathcal{N}(\theta_{MAP}, P^{-1}).
Parameters
n_samples:int, default=100- number of samples
generator:torch.Generator, optional- random number generator to control the samples
Returns
samples:torch.Tensor
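A sketch combining sample and log_prob above, assuming a fitted `la`:

```python
import torch

# Draw parameter vectors from the Gaussian posterior approximation;
# thetas has shape (n_samples, n_parameters).
thetas = la.sample(n_samples=10, generator=torch.Generator().manual_seed(0))

# Evaluate the normalized posterior log density of each sampled parameter vector.
log_probs = torch.stack([la.log_prob(theta) for theta in thetas])
```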
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any]) ‑> None
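A sketch of serializing and restoring a fitted approximation; the file name is a placeholder, and the restored object must be created with the same model and configuration before loading:

```python
import torch

# Save the fitted Laplace approximation (not the model weights themselves).
torch.save(la.state_dict(), "laplace_state.pt")

# Later: rebuild an identically configured Laplace object and restore it.
la2 = DiagLaplace(model, likelihood="regression")
la2.load_state_dict(torch.load("laplace_state.pt"))
```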
Inherited members
class FunctionalLaplace (model: nn.Module, likelihood: Likelihood | str, n_subset: int, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x='inputs_id', dict_key_y='labels', backend: type[CurvatureInterface] | None = laplace.curvature.backpack.BackPackGGN, backend_kwargs: dict[str, Any] | None = None, independent_outputs: bool = False, seed: int = 0)-
Applying the GGN (Generalized Gauss-Newton) approximation for the Hessian in the Laplace approximation of the posterior turns the underlying probabilistic model from a BNN into a GLM (generalized linear model). This GLM (in the weight space) is equivalent to a GP (in the function space), see Approximate Inference Turns Deep Networks into Gaussian Processes (Khan et al., 2019)
This class implements the (approximate) GP inference through which we obtain the desired quantities (posterior predictive, marginal log-likelihood). See Improving predictions of Bayesian neural nets via local linearization (Immer et al., 2021) for more details.
Note that for likelihood='classification', we approximate L_{NN} with a diagonal matrix (L_{NN} is a block-diagonal matrix, where blocks represent Hessians of the per-data-point log likelihood w.r.t. the neural network output f; see Appendix A.2.1 for the exact definition). We resort to such an approximation because of the (possible) errors found in the Laplace approximation for multiclass GP classification in Chapter 3.5 of the R&W 2006 GP book; see the question here for more details. Alternatively, one could also resort to one-vs-one or one-vs-rest implementations for multiclass classification; however, that is not (yet) supported here.
Parameters
n_subset : int
    Number of data points for Subset-of-Data (SOD) approximate GP inference.
independent_outputs : bool
    The GP kernel here is a product of Jacobians, which results in a C \times C matrix, where C is the output dimension. If independent_outputs=True, only the diagonal of the GP kernel is used. This is (somewhat) equivalent to assuming independent GPs across output channels.
See the BaseLaplace class for the full interface.
Ancestors
Subclasses
Instance variables
var gp_kernel_prior_variance
var log_det_ratio : torch.Tensor
    Computes the log determinant term in the GP marginal likelihood.
    For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book (note that we always use the diagonal approximation D of the Hessian of the log likelihood w.r.t. f):
    log determinant term := \log | I + D^{1/2} K D^{1/2} |
    For regression, we use the "standard" GP marginal likelihood:
    log determinant term := \log | K + \sigma^2 I |
var scatter : torch.Tensor
    Compute the scatter term in the GP log marginal likelihood.
    For classification we use eq. (3.44) from Chapter 3.5 of the R&W 2006 GP book with \hat{f} = f:
    scatter term := f K^{-1} f^{T}
    For regression, we use the "standard" GP marginal likelihood:
    scatter term := (y - m) K^{-1} (y - m)^T, where m is the mean of the GP prior, which in our case corresponds to m := f + J (\theta - \theta_{MAP})
var prior_precision
Methods
def fit(self, train_loader: DataLoader | MutableMapping, progress_bar: bool = False)-
Fit the Laplace approximation of a GP posterior.
Parameters
train_loader : torch.data.utils.DataLoader
    train_loader.dataset needs to be set to access N, the size of the data set. train_loader.batch_size needs to be set to access b, the batch size.
progress_bar : bool
    Whether to show a progress bar during the fitting process.
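A minimal sketch, where `model` and `train_loader` are placeholders for a trained classifier and its (X, y) DataLoader with batch_size set; n_subset controls the subset-of-data GP approximation:

```python
from laplace.baselaplace import FunctionalLaplace

la_gp = FunctionalLaplace(
    model,                        # placeholder trained nn.Module classifier
    likelihood="classification",
    n_subset=64,                  # number of data points kept for the GP
)
la_gp.fit(train_loader)           # train_loader.dataset and .batch_size must be set
la_gp.optimize_prior_precision(pred_type="gp", method="marglik")
```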
def predictive_samples(self, x: torch.Tensor | MutableMapping[str, torch.Tensor | Any], pred_type: PredType | str = PredType.GLM, n_samples: int = 100, diagonal_output: bool = False, generator: torch.Generator | None = None) ‑> torch.Tensor-
Sample from the posterior predictive on input data x. Can be used, for example, for Thompson sampling.
Parameters
x : torch.Tensor or MutableMapping
    Input data (batch_size, input_shape).
pred_type : {'glm'}, default='glm'
    Type of posterior predictive, linearized GLM predictive.
n_samples : int
    Number of samples.
diagonal_output : bool
    Whether to use a diagonalized glm posterior predictive on the outputs. Only applies when pred_type='glm'.
generator : torch.Generator, optional
    Random number generator to control the samples (if sampling is used).
Returns
samples:torch.Tensor- samples
(n_samples, batch_size, output_shape)
def functional_variance(self, Js_star: torch.Tensor) ‑> torch.Tensor-
GP posterior variance:
k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
Parameters
Js_star : torch.Tensor of shape (N*, C, P)
    Jacobians of test data points.
Returns
f_var : torch.Tensor of shape (N*, C, C)
    Contains the posterior variances of N* testing points.
def functional_covariance(self, Js_star: torch.Tensor) ‑> torch.Tensor-
GP posterior covariance:
k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}
Parameters
Js_star : torch.Tensor of shape (N*, C, P)
    Jacobians of test data points.
Returns
f_var : torch.Tensor of shape (N*xC, N*xC)
    Contains the posterior covariances of N* testing points.
def optimize_prior_precision(self, pred_type: PredType | str = PredType.GP, method: TuningMethod | str = TuningMethod.MARGLIK, n_steps: int = 100, lr: float = 0.1, init_prior_prec: float | torch.Tensor = 1.0, prior_structure: PriorStructure | str = PriorStructure.SCALAR, val_loader: DataLoader | None = None, loss: torchmetrics.Metric | Callable[[torch.Tensor], torch.Tensor | float] | None = None, log_prior_prec_min: float = -4, log_prior_prec_max: float = 4, grid_size: int = 100, link_approx: LinkApprox | str = LinkApprox.PROBIT, n_samples: int = 100, verbose: bool = False, progress_bar: bool = False) ‑> None-
optimize_prior_precision_base from BaseLaplace with pred_type='gp'.
def log_marginal_likelihood(self, prior_precision: torch.Tensor | None = None, sigma_noise: torch.Tensor | None = None) ‑> torch.Tensor
Compute the Laplace approximation to the log marginal likelihood. Requires that the Laplace approximation has been fit before. The resulting torch.Tensor is differentiable in prior_precision and sigma_noise if these have gradients enabled. By passing prior_precision or sigma_noise, the current value is overwritten. This is useful for iterating on the log marginal likelihood.
Parameters
prior_precision : torch.Tensor, optional
    Prior precision if it should be changed from the current prior_precision value.
sigma_noise : torch.Tensor, optional
    Observation noise standard deviation if it should be changed.
Returns
log_marglik:torch.Tensor
def state_dict(self) ‑> dict
def load_state_dict(self, state_dict: dict)
Inherited members
class FullLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'input_ids', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None)-
Laplace approximation with a full, i.e., dense, log likelihood Hessian approximation and hence posterior precision. Based on the chosen backend parameter, the full approximation can be, for example, a generalized Gauss-Newton matrix. Mathematically, we have P \in \mathbb{R}^{P \times P}. See BaseLaplace for the full interface.
Ancestors
Subclasses
Instance variables
var posterior_scale : torch.Tensor-
Posterior scale (square root of the covariance), i.e., P^{-\frac{1}{2}}.
Returns
scale:torch.tensor(parameters, parameters)
var posterior_covariance : torch.Tensor-
Posterior covariance, i.e., P^{-1}.
Returns
covariance:torch.tensor(parameters, parameters)
var posterior_precision : torch.Tensor-
Posterior precision P.
Returns
precision:torch.tensor(parameters, parameters)
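A sketch of inspecting these dense posterior quantities, assuming `la_full` is a hypothetical fitted FullLaplace object; the consistency check at the end is only illustrative:

```python
import torch

prec = la_full.posterior_precision     # (parameters, parameters) precision P
cov = la_full.posterior_covariance     # (parameters, parameters) covariance P^{-1}
scale = la_full.posterior_scale        # (parameters, parameters) square root of the covariance

# The scale factor reproduces the covariance up to numerical error.
assert torch.allclose(scale @ scale.T, cov, atol=1e-4)
```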
Inherited members
class KronLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, damping: bool = False, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)-
Laplace approximation with a Kronecker factored log likelihood Hessian approximation and hence posterior precision. Mathematically, we have for each parameter group, e.g., torch.nn.Module, that P \approx Q \otimes H. See BaseLaplace for the full interface and see Kron and KronDecomposed for the structure of the Kronecker factors. Kron is used to aggregate factors by summing up, and KronDecomposed is used to add the prior, a Hessian factor (e.g. temperature), and to compute posterior covariances, the marginal likelihood, etc. Damping can be enabled by setting damping=True.
Ancestors
Subclasses
Instance variables
var posterior_precision : KronDecomposed
var prior_precision : torch.Tensor
Methods
def state_dict(self) ‑> dict[str, typing.Any]
def load_state_dict(self, state_dict: dict[str, Any])
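A minimal usage sketch, where `model` and `train_loader` are placeholders for a trained classifier and its (X, y) DataLoader; damping is the extra constructor flag mentioned above:

```python
from laplace.baselaplace import KronLaplace

la_kron = KronLaplace(
    model,                        # placeholder trained nn.Module
    likelihood="classification",
    damping=True,                 # enable damping of the Kronecker factors
)
la_kron.fit(train_loader)
la_kron.optimize_prior_precision(pred_type="glm", method="marglik")
```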
Inherited members
class DiagLaplace (model: nn.Module, likelihood: Likelihood | str, sigma_noise: float | torch.Tensor = 1.0, prior_precision: float | torch.Tensor = 1.0, prior_mean: float | torch.Tensor = 0.0, temperature: float = 1.0, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend: type[CurvatureInterface] | None = None, backend_kwargs: dict[str, Any] | None = None, asdl_fisher_kwargs: dict[str, Any] | None = None)-
Laplace approximation with a diagonal log likelihood Hessian approximation and hence posterior precision. Mathematically, we have P \approx \textrm{diag}(P). See BaseLaplace for the full interface.
Ancestors
Subclasses
Instance variables
var posterior_precision : torch.Tensor-
Diagonal posterior precision p.
Returns
precision:torch.tensor(parameters)
var posterior_scale : torch.Tensor-
Diagonal posterior scale \sqrt{p^{-1}}.
Returns
scale : torch.tensor (parameters)
var posterior_variance : torch.Tensor-
Diagonal posterior variance p^{-1}.
Returns
variance : torch.tensor (parameters)
Inherited members
class LowRankLaplace (model: nn.Module, likelihood: Likelihood | str, backend: type[CurvatureInterface] = laplace.curvature.curvature.CurvatureInterface, sigma_noise: float | torch.Tensor = 1, prior_precision: float | torch.Tensor = 1, prior_mean: float | torch.Tensor = 0, temperature: float = 1, enable_backprop: bool = False, dict_key_x: str = 'inputs_id', dict_key_y: str = 'labels', backend_kwargs: dict[str, Any] | None = None)-
Laplace approximation with a low-rank log likelihood Hessian (approximation). The low-rank matrix is represented by an eigendecomposition (vecs, values). Based on the chosen backend, either a true Hessian or, for example, a GGN approximation could be used. The posterior precision is computed as P = V diag(l) V^T + P_0. To sample, compute the functional variance, and compute the log determinant, algebraic tricks are used to reduce the cost of inversion to that of a K \times K matrix if we have a rank of K.
Note that only the AsdfghjklHessian backend is supported. Install it via: pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl
See BaseLaplace for the full interface.
Ancestors
Instance variables
var V : torch.Tensor
var Kinv : torch.Tensor
var posterior_precision : tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]
    Return the correctly scaled posterior precision that would be constructed as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
    Returns
    H : tuple(eigenvectors, eigenvalues)
        Scaled self.H with temperature and loss factors.
    prior_precision_diag : torch.Tensor
        Diagonal prior precision of shape parameters to be added to H.
Inherited members