All functions and types

Exported functions

LaplaceRedux.LaplaceType
Laplace

Concrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.

Fields

  • model::Flux.Chain: The model to be approximated.
  • likelihood::Symbol: The likelihood function to be used.
  • est_params::EstimationParams: The estimation parameters.
  • prior::Prior: The parameters defining the prior distribution.
  • posterior::Posterior: The posterior distribution.
source
LaplaceRedux.LaplaceMethod
Laplace(model::Any; likelihood::Symbol, kwargs...)

Outer constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.

Arguments

  • model::Any: The model to be approximated (a Flux.Chain).
  • likelihood::Symbol: The likelihood function to be used. Possible values are :regression and :classification.

Keyword Arguments

See LaplaceParams for a description of the keyword arguments.

Returns

  • la::Laplace: The Laplace object.

Examples

using Flux, LaplaceRedux
nn = Chain(Dense(2,1))
la = Laplace(nn, likelihood=:regression)
source
LaplaceRedux.LaplaceClassifierType
LaplaceClassifier

A model type for constructing a laplace classifier, based on LaplaceRedux.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

Do model = LaplaceClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in LaplaceClassifier(model=...).

LaplaceClassifier implements, for classification models, the approach of Laplace Redux – Effortless Bayesian Deep Learning, originally published in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems, Article No. 1537, pp. 20089–20103.

Training data

In MLJ or MLJBase, given a dataset X,y and a Flux_Chain adapted to the dataset, pass the chain to the model

laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...)

then bind an instance laplace_model to data with

mach = machine(laplace_model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)

  • y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)

Train the machine using fit!(mach, rows=...).

Hyperparameters (format: name-type-default value-restrictions)

  • model::Union{Flux.Chain,Nothing} = nothing: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.

  • flux_loss = Flux.Losses.logitcrossentropy: a Flux loss function

  • optimiser = Adam(): a Flux optimiser

  • epochs::Integer = 1000::(_ > 0): the number of training epochs.

  • batch_size::Integer = 32::(_ > 0): the batch size.

  • subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)): the subset of weights to use, either :all, :last_layer, or :subnetwork.

  • subnetwork_indices = nothing: the indices of the subnetworks.

  • hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)): the structure of the Hessian matrix, either :full or :diagonal.

  • backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)): the backend to use, either :GGN or :EmpiricalFisher.

  • observational_noise (alias σ)::Float64 = 1.0: the standard deviation of the observational noise.

  • prior_mean (alias μ₀)::Float64 = 0.0: the mean of the prior distribution.

  • prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing: the precision matrix of the prior distribution.

  • fit_prior_nsteps::Int = 100::(_ > 0): the number of steps used to fit the priors.

  • link_approx::Symbol = :probit::(_ in (:probit, :plugin)): the approximation to adopt to compute the probabilities.

Operations

  • predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.

  • predict_mode(mach, Xnew): instead return the mode of each prediction above.

Fitted parameters

The fields of fitted_params(mach) are:

  • mean: The mean of the posterior distribution.

  • H: The Hessian of the posterior distribution.

  • P: The precision matrix of the posterior distribution.

  • cov_matrix: The covariance matrix of the posterior distribution.

  • n_data: The number of data points.

  • n_params: The number of parameters.

  • n_out: The number of outputs.

  • loss: The loss value of the posterior distribution.

Report

The fields of report(mach) are:

  • loss_history: an array containing the total loss per epoch.

Accessor functions

  • training_losses(mach): return the loss history from report

Examples

using MLJ
LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux

X, y = @load_iris

# Define the Flux Chain model
using Flux
model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 3)
)

#Define the LaplaceClassifier
model = LaplaceClassifier(model=model)

mach = machine(model, X, y) |> fit!

Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew)   # point predictions
training_losses(mach)      # loss history per epoch
pdf.(yhat, "virginica")    # probabilities for the "virginica" class
fitted_params(mach)        # NamedTuple with the fitted params of Laplace

See also LaplaceRedux.jl.

source
LaplaceRedux.LaplaceRegressorType
LaplaceRegressor

A model type for constructing a laplace regressor, based on LaplaceRedux.jl, and implementing the MLJ model interface.

From MLJ, the type can be imported using

LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux

Do model = LaplaceRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in LaplaceRegressor(model=...).

LaplaceRegressor implements, for regression models, the approach of Laplace Redux – Effortless Bayesian Deep Learning, originally published in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems, Article No. 1537, pp. 20089–20103.

Training data

In MLJ or MLJBase, given a dataset X,y and a Flux_Chain adapted to the dataset, pass the chain to the model

laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...)

then bind an instance laplace_model to data with

mach = machine(laplace_model, X, y)

where

  • X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)

  • y: the target, which can be any AbstractVector whose element scitype is <:Continuous; check the scitype with scitype(y)

Train the machine using fit!(mach, rows=...).

Hyperparameters (format: name-type-default value-restrictions)

  • model::Union{Flux.Chain,Nothing} = nothing: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.

  • flux_loss = Flux.Losses.logitcrossentropy: a Flux loss function

  • optimiser = Adam(): a Flux optimiser

  • epochs::Integer = 1000::(_ > 0): the number of training epochs.

  • batch_size::Integer = 32::(_ > 0): the batch size.

  • subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)): the subset of weights to use, either :all, :last_layer, or :subnetwork.

  • subnetwork_indices = nothing: the indices of the subnetworks.

  • hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)): the structure of the Hessian matrix, either :full or :diagonal.

  • backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)): the backend to use, either :GGN or :EmpiricalFisher.

  • observational_noise (alias σ)::Float64 = 1.0: the standard deviation of the observational noise.

  • prior_mean (alias μ₀)::Float64 = 0.0: the mean of the prior distribution.

  • prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing: the precision matrix of the prior distribution.

  • fit_prior_nsteps::Int = 100::(_ > 0): the number of steps used to fit the priors.

Operations

  • predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.

  • predict_mode(mach, Xnew): instead return the mode of each prediction above.

Fitted parameters

The fields of fitted_params(mach) are:

  • mean: The mean of the posterior distribution.

  • H: The Hessian of the posterior distribution.

  • P: The precision matrix of the posterior distribution.

  • cov_matrix: The covariance matrix of the posterior distribution.

  • n_data: The number of data points.

  • n_params: The number of parameters.

  • n_out: The number of outputs.

  • loss: The loss value of the posterior distribution.

Report

The fields of report(mach) are:

  • loss_history: an array containing the total loss per epoch.

Accessor functions

  • training_losses(mach): return the loss history from report

Examples

using MLJ
using Flux
LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=model)

X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) |> fit!

Xnew, _ = make_regression(3, 4; rng=123)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew)   # point predictions
training_losses(mach)      # loss history per epoch
fitted_params(mach)        # NamedTuple with the fitted params of Laplace

See also LaplaceRedux.jl.

source
LaplaceRedux.empirical_frequency_binary_classificationMethod
empirical_frequency_binary_classification(y_binary, distributions::Vector{Bernoulli{Float64}}; n_bins::Int=20)

FOR BINARY CLASSIFICATION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$, let $p_t = H(x_t) ∈ [0,1]$ be the forecasted probability.
We group the $p_t$ into intervals $I_j$ for $j = 1,2,...,m$ that form a partition of $[0,1]$. The function computes the observed average $p_j = T_j^{-1} \sum_{t: p_t ∈ I_j} y_t$ in each interval $I_j$, where $T_j$ is the number of $p_t$ falling in $I_j$.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the null class.
- distributions: an array of Bernoulli distributions
- n_bins: number of equally spaced bins to use.

Outputs:
- num_p_per_interval: array with the number of probabilities falling within interval.
- emp_avg: array with the observed empirical average per interval.
- bin_centers: array with the centers of the bins.
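
The binning described above can be sketched in plain Julia. This is a minimal illustration, not the package's implementation: the helper name `binned_frequency` is hypothetical, it takes the forecast probabilities $p_t$ as a plain vector instead of Bernoulli distributions, and reporting empty bins as NaN is an assumption.

```julia
# Hypothetical helper: `probs` holds the forecast probabilities p_t directly.
function binned_frequency(y_binary, probs; n_bins::Int=20)
    edges = range(0.0, 1.0; length=n_bins + 1)
    counts = zeros(Int, n_bins)      # number of p_t per interval (T_j)
    sums = zeros(Float64, n_bins)    # sum of y_t per interval
    for (y, p) in zip(y_binary, probs)
        # Bin index in 1..n_bins; p == 1 is assigned to the last bin.
        j = min(floor(Int, p * n_bins) + 1, n_bins)
        counts[j] += 1
        sums[j] += y
    end
    # Observed average per interval; empty intervals give NaN (assumption).
    emp_avg = [c == 0 ? NaN : s / c for (s, c) in zip(sums, counts)]
    bin_centers = [(edges[j] + edges[j + 1]) / 2 for j in 1:n_bins]
    return counts, emp_avg, bin_centers
end

counts, emp_avg, centers = binned_frequency([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]; n_bins=2)
```

For a well-calibrated classifier, emp_avg should lie close to bin_centers in every populated interval.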

source
LaplaceRedux.empirical_frequency_regressionMethod
empirical_frequency_regression(Y_cal, distributions::Distributions.Normal, n_bins=20)

Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$ and an array of predicted distributions, the function calculates the empirical frequency

\[\hat{p}_j = \frac{|\{y_t \mid F_t(y_t) \le p_j,\ t = 1,\dots,T\}|}{T},\]

where $T$ is the number of calibration points, $p_j$ is the confidence level and $F_t$ is the cumulative distribution function of the predicted distribution targeting $y_t$.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- Y_cal: a vector of values $y_t$
- distributions: a Vector{Distributions.Normal{Float64}} of distributions stacked row-wise. For example the output of LaplaceRedux.predict(la,Xcal).
- n_bins: number of equally spaced bins to use.

Outputs:
- counts: an array containing the empirical frequencies for each quantile interval.
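
As a rough sketch of the empirical frequency defined above, the snippet below assumes the values $F_t(y_t)$ have already been computed and are passed as a plain vector. The helper name `empirical_quantile_frequency` and the choice of n_bins + 1 equally spaced confidence levels are illustrative assumptions, not the package's code.

```julia
# Hypothetical helper: `cdf_values[t]` is assumed to hold F_t(y_t), the predicted
# CDF evaluated at the observed target y_t.
function empirical_quantile_frequency(cdf_values; n_bins::Int=20)
    levels = range(0.0, 1.0; length=n_bins + 1)   # confidence levels p_j
    T = length(cdf_values)
    # Fraction of calibration points whose CDF value does not exceed p_j.
    return [count(<=(p), cdf_values) / T for p in levels]
end

freqs = empirical_quantile_frequency([0.1, 0.4, 0.6, 0.9]; n_bins=4)
```

A calibrated regressor yields frequencies close to the confidence levels themselves.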

source
LaplaceRedux.extract_mean_and_varianceMethod
extract_mean_and_variance(distr::Vector{Normal{<: AbstractFloat}})

Extract the mean and the variance of each distribution and return them in two separate lists.

Inputs:
- distributions: a Vector of Normal distributions

Outputs:
- means: the list of the means
- variances: the list of the variances

source
LaplaceRedux.fit!Method
fit!(la::AbstractLaplace,data)

Fits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).

Examples

using Flux, LaplaceRedux
x, y = LaplaceRedux.Data.toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
source
LaplaceRedux.glm_predictive_distributionMethod
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)

Computes the linearized GLM predictive from neural network with a Laplace approximation to the posterior $p(\theta|\mathcal{D})=\mathcal{N}(\hat\theta,\Sigma)$. This is the distribution on network outputs given by $p(f(x)|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta})$. For the Bayesian predictive distribution, see predict.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.

Returns

  • normal_distr: A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.
  • fμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.
  • fvar::AbstractArray: Variance of the predictive distribution. The output shape is column-major as in Flux.

Examples

using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
source
LaplaceRedux.optimize_prior!Method
optimize_prior!(
    la::AbstractLaplace;
    n_steps::Int=100, lr::Real=1e-1,
    λinit::Union{Nothing,Real}=nothing,
    σinit::Union{Nothing,Real}=nothing
)

Optimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).

source
LaplaceRedux.posterior_covarianceFunction
posterior_covariance(la::AbstractLaplace, P=la.P)

Computes the posterior covariance $\Sigma$ as the inverse of the posterior precision: $\Sigma=P^{-1}$.

source
LaplaceRedux.posterior_precisionFunction
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix)

Computes the posterior precision $P$ for a fitted Laplace Approximation as follows,

$P = \sum_{n=1}^N\nabla_{\theta}^2 \log p(\mathcal{D}_n|\theta)|_{\hat\theta} + \nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}$

where $\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H$ is the Hessian and $\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0$ is the prior precision and $\hat\theta$ is the MAP estimate.
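
The relation between $H$, $P_0$, $P$ and $\Sigma$ can be illustrated with a toy example using only the LinearAlgebra standard library. The numbers are made up and the snippet does not call into LaplaceRedux; it only shows the arithmetic $P = H + P_0$ and $\Sigma = P^{-1}$.

```julia
using LinearAlgebra

H = [4.0 1.0; 1.0 3.0]      # toy Hessian at the MAP estimate (made-up numbers)
P0 = 0.5 * I                # isotropic prior precision as a UniformScaling
P = H + P0                  # posterior precision
Σ = inv(Symmetric(P))       # posterior covariance as the inverse of P
```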

source
LaplaceRedux.predictMethod
predict(
    la::AbstractLaplace,
    X::AbstractArray;
    link_approx=:probit,
    predict_proba::Bool=true,
    ret_distr::Bool=false,
)

Computes the Bayesian predictive distribution from a neural network with a Laplace approximation to the posterior $p(\theta | \mathcal{D}) = \mathcal{N}(\hat\theta, \Sigma)$.

Arguments

  • la::AbstractLaplace: A Laplace object.
  • X::AbstractArray: Input data.
  • link_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.
  • predict_proba::Bool=true: If true (default) apply a sigmoid or a softmax function to the output of the Flux model.
  • ret_distr::Bool=false: If false (default), the function outputs either the direct output of the chain or pseudo-probabilities (if predict_proba=true). If true, predict returns a probability distribution.

Returns

For classification tasks:

  1. If ret_distr is false, predict returns the mean of the predictive distribution, which corresponds to the MAP estimate if the link function is set to :plugin, otherwise to the probit approximation. The output shape is column-major as in Flux.
  2. If ret_distr is true, predict returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclass classification tasks.

For regression tasks:

  1. If ret_distr is false, predict returns the mean and the variance of the predictive distribution. The output shape is column-major as in Flux.
  2. If ret_distr is true, predict returns the predictive posterior distribution, namely:

$p(y|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta} + \sigma^2 \mathbf{I})$

Examples

using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
predict(la, hcat(x...))
source
LaplaceRedux.rescale_stddevMethod
rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}

Rescale the standard deviation of the Normal distributions received as argument and return a vector of rescaled Normal distributions.

Inputs:
- distr: a Vector of Normal distributions
- s: a scale factor of type T.

Outputs:
- Vector{Normal{T}}: a Vector of rescaled Normal distributions.

source
LaplaceRedux.sharpness_classificationMethod
sharpness_classification(y_binary,distributions::Distributions.Bernoulli)

Dispatched version for Bernoulli distributions FOR BINARY CLASSIFICATION MODELS.
Assess the sharpness of the model by looking at the distribution of model predictions. When forecasts are sharp, most predictions are close to either 0 or 1.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the negative result.
- distributions: an array of Bernoulli distributions describing the probability of the output belonging to the target class

Outputs:
- mean_class_one: a scalar that measures the average prediction for the target class
- mean_class_zero: a scalar that measures the average prediction for the null class
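
One possible reading of these two outputs is sketched below. It assumes that the "prediction for the null class" is 1 - p for samples labelled 0, and it takes the predicted probabilities as a plain vector; the helper name `sharpness_sketch` is hypothetical and the package's own computation may differ.

```julia
using Statistics

# Hypothetical helper: `probs[t]` is the predicted probability of the target class.
function sharpness_sketch(y_binary, probs)
    # Average predicted probability of class 1 over samples labelled 1.
    mean_class_one = mean(probs[y_binary .== 1])
    # Average predicted probability of class 0 over samples labelled 0
    # (assumption: the null-class prediction is 1 - p).
    mean_class_zero = mean(1 .- probs[y_binary .== 0])
    return mean_class_one, mean_class_zero
end

m1, m0 = sharpness_sketch([1, 1, 0, 0], [0.9, 0.8, 0.2, 0.4])
```

Under this reading, values of both outputs close to 1 indicate sharp forecasts.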

source
LaplaceRedux.sharpness_regressionMethod
sharpness_regression(distributions::Distributions.Normal)

Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$ and an array of predicted distributions, the function calculates the sharpness of the predicted distributions, i.e., the average of the variances $\sigma^2(F_t)$ predicted by the forecaster for each $x_t$.
Source: Kuleshov, Fenner, Ermon 2018

Inputs:
- distributions: an array of normal distributions $F(x_t)$ stacked row-wise.

Outputs:
- sharpness: a scalar that measures the level of sharpness of the regressor

source
LaplaceRedux.sigma_scalingMethod
sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}) where T <: AbstractFloat

Compute the value of Σ that maximizes the conditional log-likelihood, i.e. that minimizes

\[ m \ln(\Sigma) + \frac{1}{2} \Sigma^{-2} \sum_{i=1}^{m} \frac{\lVert y_{cal,i} - \bar{y}_{mean,i} \rVert^2}{\sigma_i^2}, \]

where m is the number of elements in the calibration set (xcal,ycal).
Source: Laves, Ihler, Fast, Kahrs, Ortmaier, 2020

Inputs:
- distr: a Vector of Normal distributions
- y_cal: a Vector of true results.

Outputs:
- sigma: the scalar that maximizes the likelihood.
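
Setting the derivative of the objective above with respect to Σ to zero gives the stationary point Σ² = (1/m) ∑ (y_cal_i - y_mean_i)² / σ_i². The sketch below evaluates that closed form on toy numbers; the names `objective` and `closed_form_sigma` are hypothetical, and means and standard deviations are passed as plain vectors rather than Normal distributions. Whether the package uses this closed form internally is not stated here.

```julia
using Statistics

# Objective from the formula above, with μ and σ the predicted means and
# standard deviations and y the calibration targets.
objective(s, μ, σ, y) = length(y) * log(s) + sum(((y .- μ) ./ σ) .^ 2) / (2 * s^2)

# Stationary point of the objective: root mean square of the standardized residuals.
closed_form_sigma(μ, σ, y) = sqrt(mean(((y .- μ) ./ σ) .^ 2))

μ = [0.0, 1.0, 2.0]
σ = [1.0, 1.0, 2.0]
y = [2.0, 1.0, 6.0]
s = closed_form_sigma(μ, σ, y)
```

A value s > 1 indicates that the predicted standard deviations are too small on the calibration set, and s < 1 that they are too large.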

source
MLJModelInterface.predictMethod

function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)

Predicts the response for new data using a fitted Laplace model.

Arguments

  • m::LaplaceRegressor: The Laplace model.
  • fitresult: The result of the fitting procedure.
  • Xnew: The new data for which predictions are to be made.

Returns

for LaplaceRegressor:
- An array of Normal distributions, each parameterized by the predicted mean and variance for the corresponding input in `Xnew`.
for LaplaceClassifier:
- `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
source

Internal functions

LaplaceRedux.AbstractLaplaceMethod
(la::AbstractLaplace)(X::AbstractArray)

Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.

source
LaplaceRedux.EstimationParamsType
EstimationParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
source
LaplaceRedux.KronType

Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the full block Hessian approximation would be approximated as 𝐀⊗𝐆.

source
LaplaceRedux.KronDecomposedType
KronDecomposed

Decomposed Kronecker-factored approximate curvature representation for a neural network model.

Decomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefits of reducing the costs for computation of inverses and log-determinants.

source
LaplaceRedux.LaplaceParamsType
LaplaceParams

Container for the parameters of a Laplace approximation.

Fields

  • subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
  • subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
  • hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
  • backend::Symbol: the backend to use. Possible values are :GGN and :EmpiricalFisher.
  • curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
  • observational_noise::Real: the observation noise
  • σ::Real: alias for observational_noise.
  • prior_mean::Real: the prior mean of the network parameters.
  • μ₀::Real: alias for prior_mean.
  • prior_precision::Real: the prior precision for the network parameters.
  • λ::Real: alias for prior_precision.
  • prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix for the network parameters.
  • P₀::Union{Nothing,AbstractMatrix,UniformScaling}: alias for prior_precision_matrix.
source
LaplaceRedux.PosteriorType
Posterior

Container for the results of a Laplace approximation.

Fields

  • posterior_mean::AbstractVector: the MAP estimate of the parameters
  • H::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix
  • P::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix
  • posterior_covariance_matrix::Union{AbstractArray,Nothing}: the posterior covariance matrix
  • n_data::Union{Int,Nothing}: the number of data points
  • n_params::Union{Int,Nothing}: the number of parameters
  • n_out::Union{Int,Nothing}: the number of outputs
  • loss::Real: the loss value
source
LaplaceRedux.PriorType
Prior

Container for the prior parameters of a Laplace approximation.

Fields

  • observational_noise::Real: the observation noise
  • prior_mean::Real: the prior mean
  • prior_precision::Real: the prior precision
  • prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
source
LaplaceRedux.PriorMethod
Prior(params::LaplaceParams)

Extracts the prior parameters from a LaplaceParams object.

source
Base.:*Method

Multiply by a scalar by changing the eigenvalues. Distribute the scalar along the factors of a block.

source
Base.:*Method

Kronecker-factored curvature scalar scaling.

source
Base.:+Method

Shift the factors by a diagonal (assumed uniform scaling)

source
Base.:+Method

Shift the factors by a scalar across the diagonal.

source
LaplaceRedux._H_factorMethod
_H_factor(la::AbstractLaplace)

Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the full Hessian.

source
LaplaceRedux._fit!Method
_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)

Fit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored Hessian.

source
LaplaceRedux._weight_penaltyMethod
_weight_penalty(la::AbstractLaplace)

The weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)

Compute the full approximation, for either a single input-output datapoint or a batch of such.

source
LaplaceRedux.approximateMethod
approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)

Compute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.

Note that the network predictive distribution is used in a weighted sum, so the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.

source
LaplaceRedux.clampMethod

Clamp eigenvalues in an eigendecomposition to be non-negative.

Since the Fisher information matrix is positive semidefinite by construction, the (near-zero) negative eigenvalues should be neglected.
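
A minimal sketch of eigenvalue clamping with the LinearAlgebra standard library is shown below. The toy matrix has a deliberately large negative eigenvalue so the effect is visible; in practice the negative values are numerical noise near zero. This illustrates the idea only and is not the package's method.

```julia
using LinearAlgebra

M = [1.0 2.0; 2.0 1.0]                  # symmetric toy matrix with eigenvalues -1 and 3
E = eigen(Symmetric(M))                 # eigenvalues are returned in ascending order
clamped = max.(E.values, 0.0)           # replace negative eigenvalues by zero
M_psd = E.vectors * Diagonal(clamped) * E.vectors'   # rebuild a PSD matrix
```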

source
LaplaceRedux.collect_trainableMethod
collect_trainable(model)

Collect all trainable parameter arrays from a Flux model, in the same order as Flux.destructure. This replaces the deprecated Flux.params API.

source
LaplaceRedux.compute_param_indicesMethod
compute_param_indices(model::Any, est_params::EstimationParams)

Compute the flat indices into the destructured parameter vector for the selected subset of weights. These indices are used for Jacobian/gradient column selection.

source
LaplaceRedux.convert_subnetwork_indicesMethod

convert_subnetwork_indices(subnetwork_indices::AbstractArray)

Converts the subnetwork indices from the user given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.

source
LaplaceRedux.dataset_shapeMethod
function dataset_shape(model::LaplaceRegression, X, y)

Compute the number of features of the X input dataset and the number of variables to predict from the y output dataset.

Arguments

  • model::LaplaceModels: The Laplace model to fit.
  • X: The input data for training.
  • y: The target labels for training one-hot encoded.

Returns

  • (input size, output size)
source
LaplaceRedux.default_buildMethod
default_build( seed::Int, shape)

Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights.

Arguments

  • seed::Int: The seed for random number generation.
  • shape: a tuple containing the dimensions of the input layer and the output layer.

Returns

  • The constructed Flux model, which consists of a simple MLP with 2 hidden layers of 20 neurons each and input and output layers compatible with the dataset.
source
LaplaceRedux.functional_varianceMethod
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)

Computes the functional variance for the GLM predictive as map(j -> (j' * Σ * j), eachrow(𝐉)) which is a (output x output) predictive covariance matrix. Formally, we have ${\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}$ where $\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta$ is the Jacobian evaluated at the MAP estimate.

Dispatches to the appropriate method based on the Hessian structure.
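
The quadratic form above can be checked on toy numbers. The sketch assumes that each row of 𝐉 holds the gradient of one output with respect to the parameters; the values are made up and nothing here calls into LaplaceRedux.

```julia
using LinearAlgebra

Σ = [2.0 0.5; 0.5 1.0]              # toy posterior covariance over 2 parameters
J = [1.0 0.0; 1.0 1.0; 0.0 2.0]     # toy Jacobian: 3 outputs × 2 parameters (assumed layout)

full_cov = J * Σ * J'                            # 3×3 predictive covariance
per_output = map(j -> j' * Σ * j, eachrow(J))    # one variance per output
```

The row-wise map returns exactly the diagonal of the full predictive covariance.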

source
LaplaceRedux.functional_varianceMethod

functional_variance(la::Laplace, hessian_structure::FullHessian, 𝐉)

Computes the functional variance for the GLM predictive as map(j -> (j' * Σ * j), eachrow(𝐉)) which is a (output x output) predictive covariance matrix. Formally, we have ${\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}$ where $\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta$ is the Jacobian evaluated at the MAP estimate.

source
LaplaceRedux.functional_varianceMethod
functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)

Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output 𝐉=∇f(x;θ)|θ̂.

source
LaplaceRedux.get_map_estimateMethod
get_map_estimate(model::Any, est_params::EstimationParams)

Helper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.

source
LaplaceRedux.get_paramsMethod
get_params(model::Any, params::EstimationParams)

Extracts the trainable parameter arrays of a model based on the subset of weights specified in the EstimationParams object. Replaces the old Flux.params API.

source
LaplaceRedux.has_softmax_or_sigmoid_final_layerMethod
has_softmax_or_sigmoid_final_layer(model::Flux.Chain)

Check if the Flux model ends with a sigmoid or with a softmax layer.

Input:
- model: the Flux Chain object that represents the neural network.

Return:
- has_finaliser: true if the check is positive, false otherwise.

source
LaplaceRedux.instantiate_curvature!Method
instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)

Instantiates the curvature interface for a Laplace approximation. The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.

source
LaplaceRedux.inv_square_formMethod

function inv_square_form(K::KronDecomposed, W::Matrix)

Special function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ)

source
LaplaceRedux.logdetblockMethod
logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)

Log-determinant of a block in KronDecomposed, shifted by delta on the diagonal.
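
Why an eigendecomposition makes this cheap can be seen from a standard identity: the eigenvalues of 𝐀⊗𝐆 are the pairwise products of the eigenvalues of 𝐀 and 𝐆, so the shifted log-determinant is a sum of logarithms. The sketch below verifies the identity on toy symmetric factors. It is an illustration under that assumption, not the package's code.

```julia
using LinearAlgebra

A = [2.0 0.0; 0.0 3.0]      # toy symmetric Kronecker factors (made-up numbers)
G = [1.0 0.5; 0.5 1.0]
δ = 0.1                     # diagonal shift

λA = eigvals(Symmetric(A))
λG = eigvals(Symmetric(G))

# Factored form: sum over all pairs of eigenvalues, no dense Kronecker product needed.
ld_factored = sum(log(a * g + δ) for a in λA, g in λG)

# Dense reference for comparison.
ld_dense = logdet(kron(A, G) + δ * I)
```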

source
LaplaceRedux.mmMethod

Matrix-multiply for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. Return (K^x)W^T, where x is the exponent.

source
LaplaceRedux.n_paramsMethod
n_params(model::Any, params::EstimationParams)

Helper function to determine the number of parameters of a Flux.Chain with Laplace approximation.

source
LaplaceRedux.outdimMethod
outdim(model::Chain)

Helper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.

source
LaplaceRedux.outdimMethod
outdim(la::AbstractLaplace)

Helper function to determine the output dimension of a Flux.Chain with Laplace approximation, corresponding to the number of neurons in the last layer of the NN.

source
LaplaceRedux.probitMethod
probit(fμ::AbstractArray, fvar::AbstractArray)

Compute the probit approximation of the predictive distribution.
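A common form of the probit approximation rescales the predictive mean by 1/√(1 + π/8 · variance) before the link function is applied, which pulls uncertain predictions towards 0.5. The sketch below assumes that form for a binary (sigmoid) output; `probit_scale` and `logistic` are hypothetical names and the code is not taken from the library source.

```julia
# Hypothetical helper: rescale the mean logits by the probit factor.
probit_scale(fμ::AbstractArray, fvar::AbstractArray) = fμ ./ sqrt.(1 .+ (π / 8) .* fvar)

# Plain logistic link, defined locally to keep the sketch dependency-free.
logistic(z::Real) = 1 / (1 + exp(-z))

fμ = [2.0, -1.0]     # predictive means (logits)
fvar = [0.0, 4.0]    # functional variances: the first prediction is certain, the second is not

p = logistic.(probit_scale(fμ, fvar))
```

With zero variance the probability is unchanged, while the high-variance logit is shrunk, so its probability moves closer to 0.5.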

source
LinearAlgebra.detMethod
det(K::KronDecomposed)

Determinant of the KronDecomposed block-diagonal matrix, computed as the exponentiated log-determinant.

source
LinearAlgebra.logdetMethod
logdet(K::KronDecomposed)

Log-determinant of the KronDecomposed block-diagonal matrix, computed as the sum of the log-determinants of the blocks (equivalently, the log of the product of the block determinants).
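The relationship between the two methods can be illustrated with ordinary dense blocks: for a block-diagonal matrix the log-determinant is the sum of the block log-determinants, and the determinant is its exponential. The blocks below are hypothetical values, and the sketch uses plain matrices rather than the KronDecomposed type.

```julia
using LinearAlgebra

# Two symmetric positive definite diagonal blocks (hypothetical values).
blocks = [[2.0 0.5; 0.5 1.0], [3.0 0.0; 0.0 4.0]]

# Log-determinant of the block-diagonal matrix: sum over the blocks.
total_logdet = sum(logdet, blocks)

# Determinant recovered by exponentiating, which avoids overflow in intermediate products.
total_det = exp(total_logdet)

# Dense reference, assembled only for comparison.
dense = [blocks[1] zeros(2, 2); zeros(2, 2) blocks[2]]
```

Working in log space is the usual choice here because determinants of large precision matrices easily overflow or underflow in floating point.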

source
MLJModelInterface.fitMethod
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)

Fit a Laplace model using the provided features and target values.

Arguments

  • m: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
  • verbosity: Verbosity level for logging.
  • X: Input features, expected to be in a format compatible with MLJBase.matrix.
  • y: Target values.

Returns

  • fitresult: a tuple (la, decode) containing the fitted Laplace model and y[1], the first element of the categorical y vector.
  • cache: a tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
  • report: A NamedTuple containing the loss history of the fitting process.
source
MLJModelInterface.fitted_paramsMethod

MMI.fitted_params(model::LaplaceRegressor, fitresult)

This function extracts the fitted parameters from a LaplaceRegressor model.

Arguments

  • model::LaplaceRegressor: The Laplace regression model.
  • fitresult: the Laplace approximation (la).

Returns

A named tuple containing:

  • mean: The mean of the posterior distribution.
  • H: The Hessian of the posterior distribution.
  • P: The precision matrix of the posterior distribution.
  • cov_matrix: The covariance matrix of the posterior distribution.
  • n_data: The number of data points.
  • n_params: The number of parameters.
  • n_out: The number of outputs.
  • loss: The loss value of the posterior distribution.
source
MLJModelInterface.is_same_exceptMethod
MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)

If both m1 and m2 are of MLJType, return true if the following conditions all hold, and false otherwise:

  • typeof(m1) === typeof(m2)

  • propertynames(m1) === propertynames(m2)

  • with the exception of properties listed as exceptions or bound to an AbstractRNG, each pair of corresponding property values is either "equal" or both undefined. (If a property appears as a propertyname but not a fieldname, it is deemed as always defined.)

The meaning of "equal" depends on the type of the property value:

  • values that are themselves of MLJType are "equal" if they are equal in the sense of is_same_except with no exceptions.

  • values that are not of MLJType are "equal" if they are ==.

In the special case of a "deep" property, "equal" has a different meaning; see MLJBase.deep_properties for details.

If either m1 or m2 is not an MLJType object, then return ==(m1, m2).

source
MLJModelInterface.training_lossesMethod
MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report)

Retrieve the training loss history from the given report.

Arguments

  • model: The model for which the training losses are being retrieved.
  • report: An object containing the training report, which includes the loss history.

Returns

  • A collection representing the loss history from the training report.
source
MLJModelInterface.updateMethod
MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)

Update the Laplace model using the provided new data points.

Arguments

  • m: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
  • verbosity: Verbosity level for logging.
  • X: New input features, expected to be in a format compatible with MLJBase.matrix.
  • y: New target values.

Returns

  • fitresult: a tuple (la, decode) containing the updated fitted Laplace model and y[1], the first element of the categorical y vector.
  • cache: a tuple containing a deepcopy of the model, the updated current state of the optimiser and the training loss history.
  • report: A NamedTuple containing the complete loss history of the fitting process.
source
LaplaceRedux.Curvature.gradientsMethod
gradients(curvature::CurvatureInterface, X::AbstractArray, y)

Compute the gradients of the loss function with respect to the model parameters: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ. Returns a flat gradient vector for the selected parameter subset.

source
LaplaceRedux.Curvature.jacobians_unbatchedMethod
jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)

Compute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Uses Flux.destructure to obtain a flat parameter vector and computes the Jacobian via Zygote.

source