All functions and types
Exported functions
LaplaceRedux.Laplace — Type
Laplace
Concrete type for Laplace approximation. This type is a subtype of AbstractLaplace and is used to store all the necessary information for a Laplace approximation.
Fields
- model::Flux.Chain: The model to be approximated.
- likelihood::Symbol: The likelihood function to be used.
- est_params::EstimationParams: The estimation parameters.
- prior::Prior: The parameters defining the prior distribution.
- posterior::Posterior: The posterior distribution.
LaplaceRedux.Laplace — Method
Laplace(model::Any; likelihood::Symbol, kwargs...)
Outer constructor for Laplace approximation. This function constructs a Laplace object from a given model and likelihood function.
Arguments
- model::Any: The model to be approximated (a Flux.Chain).
- likelihood::Symbol: The likelihood function to be used. Possible values are :regression and :classification.
Keyword Arguments
See LaplaceParams for a description of the keyword arguments.
Returns
la::Laplace: The Laplace object.
Examples
using Flux, LaplaceRedux
nn = Chain(Dense(2,1))
la = Laplace(nn, likelihood=:regression)
LaplaceRedux.LaplaceClassifier — Type
LaplaceClassifier
A model type for constructing a laplace classifier, based on LaplaceRedux.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux
Do model = LaplaceClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in LaplaceClassifier(model=...).
LaplaceClassifier implements the Laplace Redux approach for classification models, originally published in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems, Article No. 1537, pp. 20089–20103.
Training data
In MLJ or MLJBase, given a dataset X,y and a Flux_Chain adapted to the dataset, pass the chain to the model
laplace_model = LaplaceClassifier(model = Flux_Chain,kwargs...)
then bind an instance laplace_model to data with
mach = machine(laplace_model, X, y)
where
- X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
- y: is the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y)
Train the machine using fit!(mach, rows=...).
Hyperparameters (format: name-type-default value-restrictions)
- model::Union{Flux.Chain,Nothing} = nothing: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
- flux_loss = Flux.Losses.logitcrossentropy: a Flux loss function.
- optimiser = Adam(): a Flux optimiser.
- epochs::Integer = 1000::(_ > 0): the number of training epochs.
- batch_size::Integer = 32::(_ > 0): the batch size.
- subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)): the subset of weights to use, either :all, :last_layer, or :subnetwork.
- subnetwork_indices = nothing: the indices of the subnetworks.
- hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)): the structure of the Hessian matrix, either :full or :diagonal.
- backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)): the backend to use, either :GGN or :EmpiricalFisher.
- observational_noise (alias σ)::Float64 = 1.0: the standard deviation of the observational noise.
- prior_mean (alias μ₀)::Float64 = 0.0: the mean of the prior distribution.
- prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing: the precision matrix of the prior distribution.
- fit_prior_nsteps::Int = 100::(_ > 0): the number of steps used to fit the priors.
- link_approx::Symbol = :probit::(_ in (:probit, :plugin)): the approximation to adopt to compute the probabilities.
Operations
- predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.
- predict_mode(mach, Xnew): instead return the mode of each prediction above.
Fitted parameters
The fields of fitted_params(mach) are:
- mean: The mean of the posterior distribution.
- H: The Hessian of the posterior distribution.
- P: The precision matrix of the posterior distribution.
- cov_matrix: The covariance matrix of the posterior distribution.
- n_data: The number of data points.
- n_params: The number of parameters.
- n_out: The number of outputs.
- loss: The loss value of the posterior distribution.
Report
The fields of report(mach) are:
loss_history: an array containing the total loss per epoch.
Accessor functions
training_losses(mach): return the loss history from report
Examples
using MLJ
LaplaceClassifier = @load LaplaceClassifier pkg=LaplaceRedux
X, y = @load_iris
# Define the Flux Chain model
using Flux
model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 3)
)
# Define the LaplaceClassifier
model = LaplaceClassifier(model=model)
mach = machine(model, X, y) |> fit!
Xnew = (sepal_length = [6.4, 7.2, 7.4],
        sepal_width = [2.8, 3.0, 2.8],
        petal_length = [5.6, 5.8, 6.1],
        petal_width = [2.1, 1.6, 1.9],)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
training_losses(mach) # loss history per epoch
pdf.(yhat, "virginica") # probabilities for the "virginica" class
fitted_params(mach) # NamedTuple with the fitted params of Laplace
See also LaplaceRedux.jl.
LaplaceRedux.LaplaceRegressor — Type
LaplaceRegressor
A model type for constructing a laplace regressor, based on LaplaceRedux.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
Do model = LaplaceRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in LaplaceRegressor(model=...).
LaplaceRegressor implements the Laplace Redux approach for regression models, originally published in Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., Hennig, P. (2021): "Laplace Redux – Effortless Bayesian Deep Learning", NIPS'21: Proceedings of the 35th International Conference on Neural Information Processing Systems, Article No. 1537, pp. 20089–20103.
Training data
In MLJ or MLJBase, given a dataset X,y and a Flux_Chain adapted to the dataset, pass the chain to the model
laplace_model = LaplaceRegressor(model = Flux_Chain,kwargs...)
then bind an instance laplace_model to data with
mach = machine(laplace_model, X, y)
where
- X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X)
- y: is the target, which can be any AbstractVector whose element scitype is <:Continuous; check the scitype with scitype(y)
Train the machine using fit!(mach, rows=...).
Hyperparameters (format: name-type-default value-restrictions)
- model::Union{Flux.Chain,Nothing} = nothing: Either nothing or a Flux model provided by the user and compatible with the dataset. In the former case, LaplaceRedux will use a standard MLP with 2 hidden layers with 20 neurons each.
- flux_loss = Flux.Losses.logitcrossentropy: a Flux loss function.
- optimiser = Adam(): a Flux optimiser.
- epochs::Integer = 1000::(_ > 0): the number of training epochs.
- batch_size::Integer = 32::(_ > 0): the batch size.
- subset_of_weights::Symbol = :all::(_ in (:all, :last_layer, :subnetwork)): the subset of weights to use, either :all, :last_layer, or :subnetwork.
- subnetwork_indices = nothing: the indices of the subnetworks.
- hessian_structure::Union{HessianStructure,Symbol,String} = :full::(_ in (:full, :diagonal)): the structure of the Hessian matrix, either :full or :diagonal.
- backend::Symbol = :GGN::(_ in (:GGN, :EmpiricalFisher)): the backend to use, either :GGN or :EmpiricalFisher.
- observational_noise (alias σ)::Float64 = 1.0: the standard deviation of the observational noise.
- prior_mean (alias μ₀)::Float64 = 0.0: the mean of the prior distribution.
- prior_precision_matrix (alias P₀)::Union{AbstractMatrix,UniformScaling,Nothing} = nothing: the precision matrix of the prior distribution.
- fit_prior_nsteps::Int = 100::(_ > 0): the number of steps used to fit the priors.
Operations
- predict(mach, Xnew): return predictions of the target given features Xnew having the same scitype as X above. Predictions are probabilistic, but uncalibrated.
- predict_mode(mach, Xnew): instead return the mode of each prediction above.
Fitted parameters
The fields of fitted_params(mach) are:
- mean: The mean of the posterior distribution.
- H: The Hessian of the posterior distribution.
- P: The precision matrix of the posterior distribution.
- cov_matrix: The covariance matrix of the posterior distribution.
- n_data: The number of data points.
- n_params: The number of parameters.
- n_out: The number of outputs.
- loss: The loss value of the posterior distribution.
Report
The fields of report(mach) are:
loss_history: an array containing the total loss per epoch.
Accessor functions
training_losses(mach): return the loss history from report
Examples
using MLJ
using Flux
LaplaceRegressor = @load LaplaceRegressor pkg=LaplaceRedux
model = Chain(
    Dense(4, 10, relu),
    Dense(10, 10, relu),
    Dense(10, 1)
)
model = LaplaceRegressor(model=model)
X, y = make_regression(100, 4; noise=0.5, sparse=0.2, outliers=0.1)
mach = machine(model, X, y) |> fit!
Xnew, _ = make_regression(3, 4; rng=123)
yhat = predict(mach, Xnew) # probabilistic predictions
predict_mode(mach, Xnew) # point predictions
training_losses(mach) # loss history per epoch
fitted_params(mach) # NamedTuple with the fitted params of Laplace
See also LaplaceRedux.jl.
LaplaceRedux.empirical_frequency_binary_classification — Method
empirical_frequency_binary_classification(y_binary, distributions::Vector{Bernoulli{Float64}}; n_bins::Int=20)
FOR BINARY CLASSIFICATION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$, let $p_t = H(x_t) ∈ [0,1]$ be the forecasted probability.
We group the $p_t$ into intervals $I_j$ for $j = 1,2,...,m$ that form a partition of $[0,1]$. The function computes the observed average $p_j = T_j^{-1} \sum_{t: p_t ∈ I_j} y_t$ in each interval $I_j$, where $T_j$ is the number of forecasts $p_t$ that fall in $I_j$.
Source: Kuleshov, Fenner, Ermon 2018
Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the null class.
- distributions: an array of Bernoulli distributions
- n_bins: number of equally spaced bins to use.
Outputs:
- num_p_per_interval: array with the number of probabilities falling within each interval.
- emp_avg: array with the observed empirical average per interval.
- bin_centers: array with the centers of the bins.
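The binning described above can be sketched in a few lines of plain Julia. This is only an illustration under stated assumptions, not the LaplaceRedux implementation: the helper name `binned_frequency` is hypothetical, the forecasts are passed as plain probabilities rather than Bernoulli distributions, and empty bins are reported as NaN.

```julia
# Hypothetical helper (not part of LaplaceRedux): bin the forecasts p_t into
# n_bins equal-width intervals on [0, 1] and average the labels y_t per bin.
function binned_frequency(y_binary::AbstractVector{<:Real},
                          probs::AbstractVector{<:Real}; n_bins::Int=20)
    edges = range(0, 1; length=n_bins + 1)
    bin_centers = [(edges[j] + edges[j + 1]) / 2 for j in 1:n_bins]
    counts = zeros(Int, n_bins)
    sums = zeros(Float64, n_bins)
    for (y, p) in zip(y_binary, probs)
        # Map p to its bin; p == 1 is assigned to the last bin.
        j = min(floor(Int, p * n_bins) + 1, n_bins)
        counts[j] += 1
        sums[j] += y
    end
    # Empty bins have no observed average, so they are reported as NaN.
    emp_avg = [counts[j] == 0 ? NaN : sums[j] / counts[j] for j in 1:n_bins]
    return counts, emp_avg, bin_centers
end

counts, emp_avg, centers = binned_frequency([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]; n_bins=2)
```

For a well-calibrated classifier, the entries of `emp_avg` should lie close to the corresponding `centers`.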
LaplaceRedux.empirical_frequency_regression — Method
empirical_frequency_regression(Y_cal, distributions::Distributions.Normal, n_bins=20)
Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$ and an array of predicted distributions, the function calculates the empirical frequency
\[\hat{p}_j = \frac{|\{y_t \mid F_t(y_t) \le p_j,\ t = 1,\dots,T\}|}{T},\]
where $T$ is the number of calibration points, $p_j$ is the confidence level and $F_t$ is the cumulative distribution function of the predicted distribution targeting $y_t$.
Source: Kuleshov, Fenner, Ermon 2018
Inputs:
- Y_cal: a vector of values $y_t$
- distributions: a Vector{Distributions.Normal{Float64}} of distributions stacked row-wise. For example the output of LaplaceRedux.predict(la,Xcal).
- n_bins: number of equally spaced bins to use.
Outputs:
- counts: an array containing the empirical frequencies for each quantile interval.
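As a rough sketch of the empirical-frequency formula, the snippet below counts how many CDF values fall at or below each confidence level. It assumes the values $F_t(y_t)$ have already been computed (for instance with `cdf` from Distributions.jl); the helper name `empirical_frequency` and the choice of equally spaced levels on [0, 1] are illustrative assumptions, not the package's code.

```julia
# Hypothetical helper: `cdf_values[t]` is assumed to hold F_t(y_t), the
# predicted CDF evaluated at the observed calibration target.
function empirical_frequency(cdf_values::AbstractVector{<:Real}; n_bins::Int=20)
    T = length(cdf_values)
    levels = range(0, 1; length=n_bins + 1)   # confidence levels p_j
    # Fraction of calibration points whose CDF value is at most p_j.
    freqs = [count(<=(p), cdf_values) / T for p in levels]
    return collect(levels), freqs
end

levels, freqs = empirical_frequency([0.1, 0.4, 0.6, 0.9]; n_bins=2)
```

A calibrated regressor yields `freqs` close to `levels`.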
LaplaceRedux.extract_mean_and_variance — Method
extract_mean_and_variance(distr::Vector{Normal{<: AbstractFloat}})
Extract the mean and the variance of each distribution and return them in two separate lists.
Inputs:
- distributions: a Vector of Normal distributions
Outputs:
- means: the list of the means
- variances: the list of the variances
LaplaceRedux.fit! — Method
fit!(la::AbstractLaplace,data)
Fits the Laplace approximation for a data set. The function returns the number of observations (n_data) that were used to update the Laplace object. It does not return the updated Laplace object itself because the function modifies the input Laplace object in place (as denoted by the use of '!' in the function's name).
Examples
using Flux, LaplaceRedux
x, y = LaplaceRedux.Data.toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification) # likelihood is a required keyword
fit!(la, data)
LaplaceRedux.fit! — Method
Fit the Laplace approximation, with batched data.
LaplaceRedux.glm_predictive_distribution — Method
glm_predictive_distribution(la::AbstractLaplace, X::AbstractArray)
Computes the linearized GLM predictive from a neural network with a Laplace approximation to the posterior $p(\theta|\mathcal{D})=\mathcal{N}(\hat\theta,\Sigma)$. This is the distribution on network outputs given by $p(f(x)|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta})$. For the Bayesian predictive distribution, see predict.
Arguments
- la::AbstractLaplace: A Laplace object.
- X::AbstractArray: Input data.
Returns
- normal_distr: A normal distribution N(fμ,fvar) approximating the predictive distribution p(y|X) given the input data X.
- fμ::AbstractArray: Mean of the predictive distribution. The output shape is column-major as in Flux.
- fvar::AbstractArray: Variance of the predictive distribution. The output shape is column-major as in Flux.
Examples
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
glm_predictive_distribution(la, hcat(x...))
LaplaceRedux.optimize_prior! — Method
optimize_prior!(
la::AbstractLaplace;
n_steps::Int=100, lr::Real=1e-1,
λinit::Union{Nothing,Real}=nothing,
σinit::Union{Nothing,Real}=nothing
)
Optimize the prior precision post-hoc through Empirical Bayes (marginal log-likelihood maximization).
LaplaceRedux.posterior_covariance — Function
posterior_covariance(la::AbstractLaplace, P=la.P)
Computes the posterior covariance $\Sigma$ as the inverse of the posterior precision: $\Sigma=P^{-1}$.
LaplaceRedux.posterior_precision — Function
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix)
Computes the posterior precision $P$ for a fitted Laplace Approximation as follows,
$P = \sum_{n=1}^N\nabla_{\theta}^2 \log p(\mathcal{D}_n|\theta)|_{\hat\theta} + \nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}$
where $\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H$ is the Hessian and $\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0$ is the prior precision and $\hat\theta$ is the MAP estimate.
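The relationship between $H$, $P_0$, $P$ and $\Sigma$ can be illustrated with a toy two-parameter example in plain Julia. The numbers below are made up for illustration, and the snippet works on bare matrices rather than on a Laplace object.

```julia
using LinearAlgebra

# Toy 2-parameter example; H and P₀ are made-up illustrative values.
H = [2.0 0.5; 0.5 1.0]          # curvature (Hessian approximation) at the MAP estimate
P₀ = 0.5 * I                     # isotropic prior precision (a UniformScaling)

P = H + P₀                       # posterior precision P = H + P₀
Σ = inv(Symmetric(P))            # posterior covariance Σ = P⁻¹
```

Wrapping `P` in `Symmetric` is a design choice that lets the inverse exploit the symmetry of the precision matrix.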
LaplaceRedux.predict — Method
predict(
la::AbstractLaplace,
X::AbstractArray;
link_approx=:probit,
predict_proba::Bool=true,
ret_distr::Bool=false,
)
Computes the Bayesian predictive distribution from a neural network with a Laplace approximation to the posterior $p(\theta | \mathcal{D}) = \mathcal{N}(\hat\theta, \Sigma)$.
Arguments
- la::AbstractLaplace: A Laplace object.
- X::AbstractArray: Input data.
- link_approx::Symbol=:probit: Link function approximation. Options are :probit and :plugin.
- predict_proba::Bool=true: If true (default), apply a sigmoid or a softmax function to the output of the Flux model.
- ret_distr::Bool=false: If false (default), the function outputs either the direct output of the chain or pseudo-probabilities (if predict_proba=true). If true, predict returns a probability distribution.
Returns
For classification tasks:
- If ret_distr is false, predict returns fμ, i.e. the mean of the predictive distribution, which corresponds to the MAP estimate if the link function is set to :plugin, otherwise the probit approximation. The output shape is column-major as in Flux.
- If ret_distr is true, predict returns a Bernoulli distribution in binary classification tasks and a categorical distribution in multiclass classification tasks.
For regression tasks:
- If ret_distr is false, predict returns the mean and the variance of the predictive distribution. The output shape is column-major as in Flux.
- If ret_distr is true, predict returns the predictive posterior distribution, namely:
$p(y|x,\mathcal{D})\approx \mathcal{N}(f(x;\hat\theta),{\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta} + \sigma^2 \mathbf{I})$
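To make the difference between the :plugin and :probit options concrete, here is a minimal sketch for the binary case. It assumes the commonly used probit approximation, which rescales the mean logit by $1/\sqrt{1 + \pi v / 8}$ for a logit variance $v$; the helper names are hypothetical and the package's own implementation may differ in detail.

```julia
# Logistic sigmoid.
sigmoid(z) = 1 / (1 + exp(-z))

# Probit approximation (assumed standard form): shrink the mean logit fμ
# by 1 / sqrt(1 + π/8 * fvar) before applying the sigmoid.
probit_prob(fμ, fvar) = sigmoid(fμ / sqrt(1 + π / 8 * fvar))

plugin = sigmoid(2.0)            # plug-in probability, ignores the variance
probit = probit_prob(2.0, 4.0)   # probability moderated by the logit variance
```

With a positive variance the probit probability is pulled toward 0.5, which is how posterior uncertainty softens overconfident predictions.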
Examples
using Flux, LaplaceRedux
using LaplaceRedux.Data: toy_data_linear
x, y = toy_data_linear()
data = zip(x,y)
nn = Chain(Dense(2,1))
la = Laplace(nn; likelihood=:classification)
fit!(la, data)
predict(la, hcat(x...))
LaplaceRedux.rescale_stddev — Method
rescale_stddev(distr::Vector{Normal{T}}, s::T) where {T<:AbstractFloat}
Rescale the standard deviation of the Normal distributions received as argument and return a vector of rescaled Normal distributions.
Inputs:
- distr: a Vector of Normal distributions
- s: a scale factor of type T.
Outputs:
- Vector{Normal{T}}: a Vector of rescaled Normal distributions.
LaplaceRedux.sharpness_classification — Method
sharpness_classification(y_binary,distributions::Distributions.Bernoulli)
Dispatched for Bernoulli distributions FOR BINARY CLASSIFICATION MODELS.
Assess the sharpness of the model by looking at the distribution of model predictions. When forecasts are sharp, most predictions are close to either 0 or 1.
Source: Kuleshov, Fenner, Ermon 2018
Inputs:
- y_binary: the array of outputs $y_t$ numerically coded: 1 for the target class, 0 for the negative result.
- distributions: an array of Bernoulli distributions describing the probability of the output belonging to the target class
Outputs:
- mean_class_one: a scalar that measures the average prediction for the target class
- mean_class_zero: a scalar that measures the average prediction for the null class
LaplaceRedux.sharpness_regression — Method
sharpness_regression(distributions::Distributions.Normal)
Dispatched version for Normal distributions FOR REGRESSION MODELS.
Given a calibration dataset $(x_t, y_t)$ for $t ∈ \{1,...,T\}$ and an array of predicted distributions, the function calculates the sharpness of the predicted distributions, i.e., the average of the variances $\sigma^2(F_t)$ predicted by the forecaster for each $x_t$.
Source: Kuleshov, Fenner, Ermon 2018
Inputs:
- distributions: an array of normal distributions $F(x_t)$ stacked row-wise.
Outputs:
- sharpness: a scalar that measures the level of sharpness of the regressor
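Since sharpness is defined here as the average predicted variance, it reduces to a one-line computation once the standard deviations are known. The values below are made up, and the snippet works on a plain vector rather than on Normal distributions.

```julia
using Statistics

# Made-up predicted standard deviations, one per calibration input x_t.
stds = [0.5, 1.0, 1.5]

# Sharpness: the mean of the predicted variances σ²(F_t).
sharpness = mean(stds .^ 2)
```

Lower values indicate tighter (sharper) predictive distributions.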
LaplaceRedux.sigma_scaling — Method
sigma_scaling(distr::Vector{Normal{T}}, y_cal::Vector{<:AbstractFloat}) where T <: AbstractFloat
Compute the value of Σ that maximizes the conditional log-likelihood, i.e. that minimizes the negative log-likelihood
\[ m \ln(\Sigma) + \frac{1}{2} \Sigma^{-2} \sum_{i=1}^{m} \frac{\lVert y_{\mathrm{cal},i} - \bar{y}_{\mathrm{mean},i} \rVert^2}{\sigma_i^2} \]
where m is the number of elements in the calibration set (xcal,ycal).
Source: Laves, Ihler, Fast, Kahrs, Ortmaier, 2020
Inputs:
- distr: a Vector of Normal distributions
- y_cal: a Vector of true results.
Outputs:
- sigma: the scalar that maximizes the likelihood.
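Setting the derivative of the objective above with respect to Σ to zero gives $m/\Sigma - \Sigma^{-3} S = 0$ with $S = \sum_i \lVert y_i - \bar{y}_i \rVert^2 / \sigma_i^2$, hence $\Sigma = \sqrt{S/m}$. The sketch below evaluates this closed form on plain vectors; the helper name `sigma_scale` is hypothetical and the package may compute the value differently.

```julia
# Hypothetical helper: `means` and `stds` describe the predicted Normal
# distributions, `y_cal` holds the calibration targets.
function sigma_scale(means, stds, y_cal)
    m = length(y_cal)
    # Sum of squared residuals, each standardized by its predicted std.
    S = sum(((y_cal .- means) ./ stds) .^ 2)
    return sqrt(S / m)
end

s = sigma_scale([0.0, 1.0], [1.0, 2.0], [2.0, 5.0])
```

A value above 1 suggests the predicted standard deviations are too small on the calibration set, and a value below 1 that they are too large.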
MLJModelInterface.predict — Method
function MMI.predict(m::LaplaceRegressor, fitresult, Xnew)
Predicts the response for new data using a fitted Laplace model.
Arguments
- m::LaplaceRegressor: The Laplace model.
- fitresult: The result of the fitting procedure.
- Xnew: The new data for which predictions are to be made.
Returns
for LaplaceRegressor:
- An array of Normal distributions, each parameterized by the predicted mean and variance for the corresponding input in `Xnew`.
for LaplaceClassifier:
- `MLJBase.UnivariateFinite`: The predicted class probabilities for the new data.
LaplaceRedux.Curvature.CurvatureInterface — Type
Base type for any curvature interface.
Internal functions
LaplaceRedux.AbstractDecomposition — Type
Abstract type of Hessian decompositions.
LaplaceRedux.AbstractLaplace — Type
Abstract base type for all Laplace approximations in this library. All subclasses implemented are parametric.
LaplaceRedux.AbstractLaplace — Method
(la::AbstractLaplace)(X::AbstractArray)
Calling a model with Laplace Approximation on an array of inputs is equivalent to explicitly calling the predict function.
LaplaceRedux.EstimationParams — Type
EstimationParams
Container for the parameters of a Laplace approximation.
Fields
- subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
- subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
- hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
- curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
LaplaceRedux.EstimationParams — Method
EstimationParams(params::LaplaceParams)
Extracts the estimation parameters from a LaplaceParams object.
LaplaceRedux.FullHessian — Type
Concrete type for full Hessian structure. This is the default structure.
LaplaceRedux.HessianStructure — Type
Abstract type for Hessian structure.
LaplaceRedux.Kron — Type
Kronecker-factored approximate curvature representation for a neural network model. Each element in kfacs represents two Kronecker factors (𝐆, 𝐀), such that the corresponding block of the Hessian is approximated as 𝐀⊗𝐆.
LaplaceRedux.KronDecomposed — Type
KronDecomposed
Decomposed Kronecker-factored approximate curvature representation for a neural network model.
Decomposition is required to add the prior (diagonal matrix) to the posterior (KronDecomposed). It also has the benefit of reducing the cost of computing inverses and log-determinants.
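The computational benefit can be seen on a small example: because the eigenvalues of 𝐀⊗𝐆 are the pairwise products of the eigenvalues of the factors, adding an isotropic prior δI only shifts each product by δ. The sketch below uses made-up factors and plain LinearAlgebra calls, not the KronDecomposed type itself.

```julia
using LinearAlgebra

# Made-up symmetric positive definite Kronecker factors.
A = [2.0 0.5; 0.5 1.0]
G = [1.5 0.2; 0.2 0.7]
δ = 0.3                                   # isotropic prior precision

λA = eigvals(Symmetric(A))
λG = eigvals(Symmetric(G))

# Eigenvalues of A ⊗ G are all pairwise products, so adding δI shifts each by δ.
logdet_fast = sum(log(a * g + δ) for a in λA, g in λG)

# Reference value from the explicit (and much larger) dense matrix.
logdet_dense = logdet(kron(A, G) + δ * I)
```

The fast path only needs the eigenvalues of the two small factors, whereas the dense path factorizes a matrix whose side length is the product of the factor sizes.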
LaplaceRedux.KronHessian — Type
Concrete type for Kronecker-factored Hessian structure.
LaplaceRedux.LaplaceParams — Type
LaplaceParams
Container for the parameters of a Laplace approximation.
Fields
- subset_of_weights::Symbol: the subset of weights to consider. Possible values are :all, :last_layer, and :subnetwork.
- subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}: the indices of the subnetwork. Possible values are nothing or a vector of vectors of integers.
- hessian_structure::HessianStructure: the structure of the Hessian. Possible values are :full and :kron or a concrete subtype of HessianStructure.
- backend::Symbol: the backend to use. Possible values are :GGN and :EmpiricalFisher.
- curvature::Union{Curvature.CurvatureInterface,Nothing}: the curvature interface. Possible values are nothing or a concrete subtype of CurvatureInterface.
- observational_noise::Real: the observation noise
- σ::Real: alias for observational_noise.
- prior_mean::Real: the prior mean of the network parameters.
- μ₀::Real: alias for prior_mean.
- prior_precision::Real: the prior precision for the network parameters.
- λ::Real: alias for prior_precision.
- prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix for the network parameters.
- P₀::Union{Nothing,AbstractMatrix,UniformScaling}: alias for prior_precision_matrix.
LaplaceRedux.Posterior — Type
Posterior
Container for the results of a Laplace approximation.
Fields
- posterior_mean::AbstractVector: the MAP estimate of the parameters
- H::Union{AbstractArray,AbstractDecomposition,Nothing}: the Hessian matrix
- P::Union{AbstractArray,AbstractDecomposition,Nothing}: the posterior precision matrix
- posterior_covariance_matrix::Union{AbstractArray,Nothing}: the posterior covariance matrix
- n_data::Union{Int,Nothing}: the number of data points
- n_params::Union{Int,Nothing}: the number of parameters
- n_out::Union{Int,Nothing}: the number of outputs
- loss::Real: the loss value
LaplaceRedux.Posterior — Method
Posterior(model::Any, est_params::EstimationParams)
Outer constructor for Posterior object.
LaplaceRedux.Prior — Type
Prior
Container for the prior parameters of a Laplace approximation.
Fields
- observational_noise::Real: the observation noise
- prior_mean::Real: the prior mean
- prior_precision::Real: the prior precision
- prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}: the prior precision matrix
LaplaceRedux.Prior — Method
Prior(params::LaplaceParams)
Extracts the prior parameters from a LaplaceParams object.
Base.getindex — Method
Get Kronecker-factored block representation.
Base.getindex — Method
Get i-th block of a Kronecker-factored curvature.
Base.length — Method
Number of blocks in a Kronecker-factored curvature.
LaplaceRedux._H_factor — Method
_H_factor(la::AbstractLaplace)
Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)
LaplaceRedux._fit! — Method
_fit!(la::Laplace, hessian_structure::FullHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)
Fit a Laplace approximation to the posterior distribution of a model using the full Hessian.
LaplaceRedux._fit! — Method
_fit!(la::Laplace, hessian_structure::KronHessian, data; batched::Bool=false, batchsize::Int, override::Bool=true)
Fit a Laplace approximation to the posterior distribution of a model using the Kronecker-factored Hessian.
LaplaceRedux._init_H — Method
_init_H(la::AbstractLaplace)
LaplaceRedux._weight_penalty — Method
_weight_penalty(la::AbstractLaplace)
The weight penalty term is a regularization term used to prevent overfitting. Weight regularization methods such as weight decay introduce a penalty to the loss function when training a neural network to encourage the network to use small weights. Smaller weights in a neural network can result in a model that is more stable and less likely to overfit the training dataset, in turn having better performance when making a prediction on new data.
LaplaceRedux.approximate — Method
approximate(curvature::CurvatureInterface, hessian_structure::FullHessian, d::Tuple; batched::Bool=false)
Compute the full approximation, for either a single input-output datapoint or a batch of such.
LaplaceRedux.approximate — Method
approximate(curvature::CurvatureInterface, hessian_structure::KronHessian, data; batched::Bool=false)
Compute the eigendecomposed Kronecker-factored approximate curvature as the Fisher information matrix.
Note that the network predictive distribution is used in a weighted sum, so the number of backward passes is linear in the number of target classes, e.g. 100 for CIFAR-100.
LaplaceRedux.clamp — Method
Clamp eigenvalues in an eigendecomposition to be non-negative.
Since the Fisher information matrix is positive semidefinite by construction, the (near-zero) negative eigenvalues should be neglected.
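The effect of clamping can be illustrated on a small symmetric matrix whose smallest eigenvalue is slightly negative. This is a generic sketch with made-up values using LinearAlgebra's `eigen`, not the package's `clamp` method.

```julia
using LinearAlgebra

# A symmetric matrix whose smallest eigenvalue is slightly negative
# (made-up values standing in for round-off error).
M = [1.0 0.0; 0.0 -1e-12]
E = eigen(Symmetric(M))

λ = max.(E.values, 0.0)                   # clamp negative eigenvalues to zero
M_psd = E.vectors * Diagonal(λ) * E.vectors'
```

After clamping, the reconstructed matrix is positive semidefinite, so later square roots and log-determinants stay well defined.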
LaplaceRedux.collect_trainable — Method
collect_trainable(model)
Collect all trainable parameter arrays from a Flux model, in the same order as Flux.destructure. This replaces the deprecated Flux.params API.
LaplaceRedux.compute_param_indices — Method
compute_param_indices(model::Any, est_params::EstimationParams)
Compute the flat indices into the destructured parameter vector for the selected subset of weights. These indices are used for Jacobian/gradient column selection.
LaplaceRedux.convert_subnetwork_indices — Method
convert_subnetwork_indices(subnetwork_indices::AbstractArray)
Converts the subnetwork indices from the user given format [theta, row, column] to an Int i that corresponds to the index of that weight in the flattened array of weights.
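One way to picture this conversion is sketched below. It assumes that the parameter arrays are flattened column-major and concatenated in order, that theta selects the parameter array, and that vector parameters such as biases are addressed as [theta, row]; the helper name `flat_index` and these conventions are illustrative assumptions rather than the package's actual code.

```julia
# Hypothetical helper: `sizes[k]` is the size tuple of the k-th parameter array;
# arrays are assumed to be flattened column-major and concatenated in order.
function flat_index(sizes::Vector{<:Tuple}, idx::Vector{Int})
    θ = idx[1]
    # Number of weights stored before parameter array θ.
    offset = sum((prod(s) for s in sizes[1:(θ - 1)]); init=0)
    if length(idx) == 2                     # vector parameter: [theta, row]
        return offset + idx[2]
    else                                    # matrix parameter: [theta, row, column]
        nrows = sizes[θ][1]
        return offset + (idx[3] - 1) * nrows + idx[2]
    end
end

# Parameter shapes of a small network such as Dense(2, 3) followed by Dense(3, 1).
sizes = [(3, 2), (3,), (1, 3), (1,)]
i_weight = flat_index(sizes, [1, 2, 2])     # row 2, column 2 of the first weight matrix
i_bias = flat_index(sizes, [2, 1])          # first entry of the first bias vector
```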
LaplaceRedux.dataset_shape — Method
function dataset_shape(model::LaplaceRegression, X, y)
Compute the number of features of the X input dataset and the number of variables to predict from the y output dataset.
Arguments
- model::LaplaceModels: The Laplace model to fit.
- X: The input data for training.
- y: The target labels for training one-hot encoded.
Returns
- (input size, output size)
LaplaceRedux.decompose — Method
decompose(K::Kron)
Eigendecompose Kronecker factors and turn into KronDecomposed.
LaplaceRedux.default_build — Method
default_build( seed::Int, shape)
Builds a default MLP Flux model compatible with the dimensions of the dataset, with reproducible initial weights.
Arguments
- seed::Int: The seed for random number generation.
- shape: a tuple containing the dimensions of the input layer and the output layer.
Returns
- The constructed Flux model, which consists of a simple MLP with 2 hidden layers with 20 neurons each and input and output layers compatible with the dataset.
LaplaceRedux.functional_variance — Method
functional_variance(la::AbstractLaplace, 𝐉::AbstractArray)
Computes the functional variance for the GLM predictive as map(j -> (j' * Σ * j), eachrow(𝐉)) which is a (output x output) predictive covariance matrix. Formally, we have ${\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}$ where $\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta$ is the Jacobian evaluated at the MAP estimate.
Dispatches to the appropriate method based on the Hessian structure.
LaplaceRedux.functional_variance — Method
functionalvariance(la::Laplace, hessianstructure::FullHessian, 𝐉)
Computes the functional variance for the GLM predictive as map(j -> (j' * Σ * j), eachrow(𝐉)) which is a (output x output) predictive covariance matrix. Formally, we have ${\mathbf{J}_{\hat\theta}}^\intercal\Sigma\mathbf{J}_{\hat\theta}$ where $\mathbf{J}_{\hat\theta}=\nabla_{\theta}f(x;\theta)|\hat\theta$ is the Jacobian evaluated at the MAP estimate.
LaplaceRedux.functional_variance — Method
functional_variance(la::Laplace, hessian_structure::KronHessian, 𝐉::Matrix)Compute functional variance for the GLM predictive: as the diagonal of the K×K predictive output covariance matrix 𝐉𝐏⁻¹𝐉ᵀ, where K is the number of outputs, 𝐏 is the posterior precision, and 𝐉 is the Jacobian of model output 𝐉=∇f(x;θ)|θ̂.
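The relationship between the K×K predictive covariance and its diagonal can be illustrated with plain LinearAlgebra. This is a sketch on toy numbers, not the package's implementation. The matrices `J` and `Prec` are made up, and the Jacobian is assumed to have one row per output (the orientation used internally may differ).

```julia
using LinearAlgebra

# Toy quantities (assumed for illustration): K = 2 outputs, 3 parameters.
J = [1.0 0.0 2.0;
     0.5 1.0 0.0]                 # Jacobian, one row per output
Prec = [2.0 0.3 0.0;
        0.3 1.5 0.2;
        0.0 0.2 1.0]              # posterior precision 𝐏
Sigma = inv(Symmetric(Prec))      # posterior covariance Σ = 𝐏⁻¹

fcov = J * Sigma * J'                          # K×K predictive covariance 𝐉𝐏⁻¹𝐉ᵀ
fvar = map(j -> j' * Sigma * j, eachrow(J))    # per-output quadratic forms
```

Here `fvar` equals the diagonal of `fcov`, which is the quantity the KronHessian method describes.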
LaplaceRedux.get_loss_fun — Method
get_loss_fun(likelihood::Symbol)
Helper function to choose loss function based on specified model likelihood.
LaplaceRedux.get_loss_type — Method
get_loss_type(likelihood::Symbol)
Choose loss function type based on specified model likelihood.
LaplaceRedux.get_map_estimate — Method
get_map_estimate(model::Any, est_params::EstimationParams)
Helper function to extract the MAP estimate of the parameters for the model based on the subset of weights specified in the EstimationParams object.
LaplaceRedux.get_params — Method
get_params(model::Any, params::EstimationParams)
Extracts the trainable parameter arrays of a model based on the subset of weights specified in the EstimationParams object. Replaces the old Flux.params API.
LaplaceRedux.get_params — Method
get_params(la::Laplace)
Returns the selected trainable parameter arrays for a Laplace object.
LaplaceRedux.get_prior_mean — Method
get_prior_mean(la::Laplace)
Helper function to extract the prior mean of the parameters from a Laplace approximation.
LaplaceRedux.has_softmax_or_sigmoid_final_layer — Method
has_softmax_or_sigmoid_final_layer(model::Flux.Chain)
Check whether the Flux model ends with a sigmoid or a softmax layer.
Input:
- model: the Flux Chain object that represents the neural network.
Return:
- has_finaliser: true if the check is positive, false otherwise.
LaplaceRedux.hessian_approximation — Method
hessian_approximation(la::AbstractLaplace, d; batched::Bool=false)
Computes the local Hessian approximation at a single datapoint d.
LaplaceRedux.instantiate_curvature! — Method
instantiate_curvature!(params::EstimationParams, model::Any, likelihood::Symbol, backend::Symbol)
Instantiates the curvature interface for a Laplace approximation. The curvature interface is a concrete subtype of CurvatureInterface and is used to compute the Hessian matrix. The curvature interface is stored in the curvature field of the EstimationParams object.
LaplaceRedux.interleave — Method
Interleave elements of multiple iterables in order provided.
LaplaceRedux.inv_square_form — Method
inv_square_form(K::KronDecomposed, W::Matrix)
Special function to compute the inverse square form 𝐉𝐏⁻¹𝐉ᵀ (or 𝐖𝐊⁻¹𝐖ᵀ).
LaplaceRedux.log_det_posterior_precision — Method
log_det_posterior_precision(la::AbstractLaplace)
LaplaceRedux.log_det_prior_precision — Method
log_det_prior_precision(la::AbstractLaplace)
LaplaceRedux.log_det_ratio — Method
log_det_ratio(la::AbstractLaplace)
LaplaceRedux.log_likelihood — Method
log_likelihood(la::AbstractLaplace)
LaplaceRedux.log_marginal_likelihood — Method
log_marginal_likelihood(la::AbstractLaplace; P₀::Union{Nothing,UniformScaling}=nothing, σ::Union{Nothing, Real}=nothing)
LaplaceRedux.logdetblock — Method
logdetblock(block::Tuple{Eigen,Eigen}, delta::Number)
Log-determinant of a block in KronDecomposed, shifted by delta on the diagonal.
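The identity behind such a block log-determinant can be checked with plain LinearAlgebra. The sketch below assumes that a block represents A ⊗ B + δI, with the two Eigen objects holding the eigendecompositions of the factors A and B. That layout is inferred from the signature, and the matrices are toy values.

```julia
using LinearAlgebra

A = [2.0 0.5; 0.5 1.0]                          # toy Kronecker factor
B = [1.5 0.2 0.0; 0.2 1.0 0.1; 0.0 0.1 0.8]     # toy Kronecker factor
delta = 0.3

eA = eigen(Symmetric(A))
eB = eigen(Symmetric(B))

# Eigenvalues of A ⊗ B + δI are λᵢμⱼ + δ, so the log-determinant is a sum of logs.
block_logdet = sum(log(a * b + delta) for a in eA.values, b in eB.values)

# Dense reference, feasible only because the toy matrices are tiny.
dense_logdet = logdet(kron(A, B) + delta * I)
```

The eigenvalue route never forms the full Kronecker product, which is the point of working with decomposed factors.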
LaplaceRedux.mm — Method
Matrix-multiply for the KronDecomposed Hessian approximation K and a 2-d matrix W, applying an exponent to K and transposing W before multiplication. Returns (K^x)W^T, where x is the exponent.
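A sketch of how (K^x)W^T can be formed from eigendecompositions of the factors, shown for a single block of the assumed form K = A ⊗ B + δI and checked against a dense solve for x = -1. The matrices and variable names are illustrative and not taken from the package.

```julia
using LinearAlgebra

A = [2.0 0.5; 0.5 1.0]
B = [1.5 0.2 0.0; 0.2 1.0 0.1; 0.0 0.1 0.8]
delta = 0.3
W = [1.0 0.0 2.0 -1.0 0.5 0.0;
     0.0 1.0 0.0  0.5 1.0 2.0]          # 2×6 matrix to be transposed

eA = eigen(Symmetric(A))
eB = eigen(Symmetric(B))
Q = kron(eA.vectors, eB.vectors)            # eigenvectors of A ⊗ B
lam = kron(eA.values, eB.values) .+ delta   # eigenvalues of A ⊗ B + δI

x = -1
KxWt = Q * Diagonal(lam .^ x) * Q' * W'     # (K^x) W^T via the eigenbasis

# Dense reference for x = -1.
dense = (kron(A, B) + delta * I) \ W'
```

Forming `Q` explicitly is only for readability on a toy problem. An efficient implementation would apply the factor eigenvectors without building the Kronecker product.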
LaplaceRedux.n_params — Method
n_params(model::Any, params::EstimationParams)
Helper function to determine the number of parameters of a Flux.Chain with Laplace approximation.
LaplaceRedux.n_params — Method
LaplaceRedux.n_params(la::Laplace)
Overloads the n_params function for a Laplace object.
LaplaceRedux.outdim — Method
outdim(model::Chain)
Helper function to determine the output dimension of a Flux.Chain, corresponding to the number of neurons on the last layer of the NN.
LaplaceRedux.outdim — Method
outdim(la::AbstractLaplace)
Helper function to determine the output dimension, corresponding to the number of neurons on the last layer of the NN, of a Flux.Chain with Laplace approximation.
LaplaceRedux.prior_precision — Method
prior_precision(la::Laplace)
Helper function to extract the prior precision matrix from a Laplace approximation.
LaplaceRedux.probit — Method
probit(fμ::AbstractArray, fvar::AbstractArray)
Compute the probit approximation of the predictive distribution.
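For orientation, the textbook probit approximation for a binary classifier scales the predictive mean by κ = 1/√(1 + π/8 · variance) before applying the sigmoid. The sketch below implements that standard formula. `probit_binary` and `sigmoid` are hypothetical names, and the package's own method may differ in detail, for example in how it handles multi-class outputs.

```julia
sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical helper: standard probit approximation for the binary case.
function probit_binary(fmu::AbstractArray, fvar::AbstractArray)
    kappa = 1 ./ sqrt.(1 .+ (π / 8) .* fvar)   # shrinkage factor per output
    return sigmoid.(kappa .* fmu)
end

p_confident = probit_binary([2.0], [0.0])    # zero variance: plain sigmoid
p_uncertain = probit_binary([2.0], [10.0])   # large variance: pulled toward 0.5
```

Larger functional variance shrinks the logit, so predictions become less extreme as uncertainty grows.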
LaplaceRedux.validate_subnetwork_indices — Method
validate_subnetwork_indices(subnetwork_indices::Union{Nothing,Vector{Vector{Int}}}, params)
Determines whether subnetwork_indices is a valid input for specified parameters.
LinearAlgebra.det — Method
det(K::KronDecomposed)
Determinant of the KronDecomposed block-diagonal matrix, computed as the exponentiated log-determinant.
LinearAlgebra.logdet — Method
logdet(K::KronDecomposed)
Log-determinant of the KronDecomposed block-diagonal matrix, computed as the sum of the log-determinants of the blocks (equivalently, the log of the product of the block determinants).
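The two identities used here hold for any block-diagonal matrix and can be verified on a small dense example. The blocks below are toy values and do not come from a fitted model.

```julia
using LinearAlgebra

B1 = [2.0 0.5; 0.5 1.0]
B2 = [1.5 0.2; 0.2 3.0]
M = cat(B1, B2; dims=(1, 2))         # dense block-diagonal matrix

logdet_blocks = logdet(B1) + logdet(B2)   # sum of block log-determinants
det_from_logdet = exp(logdet_blocks)      # determinant recovered by exponentiating
```

Working in log space avoids overflow or underflow when the matrix has many blocks.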
MLJModelInterface.fit — Method
MMI.fit(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
Fit a Laplace model using the provided features and target values.
Arguments
- m::Laplace: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
- verbosity: Verbosity level for logging.
- X: Input features, expected to be in a format compatible with MLJBase.matrix.
- y: Target values.
Returns
- fitresult: A tuple (la, decode) containing the fitted Laplace model and y[1], the first element of the categorical y vector.
- cache: A tuple containing a deepcopy of the model, the current state of the optimiser and the training loss history.
- report: A NamedTuple containing the loss history of the fitting process.
MLJModelInterface.fitted_params — Method
MMI.fitted_params(model::LaplaceRegressor, fitresult)
This function extracts the fitted parameters from a LaplaceRegressor model.
Arguments
- model::LaplaceRegressor: The Laplace regression model.
- fitresult: The Laplace approximation (la).
Returns
A named tuple containing:
- mean: The mean of the posterior distribution.
- H: The Hessian of the posterior distribution.
- P: The precision matrix of the posterior distribution.
- cov_matrix: The covariance matrix of the posterior distribution.
- n_data: The number of data points.
- n_params: The number of parameters.
- n_out: The number of outputs.
- loss: The loss value of the posterior distribution.
MLJModelInterface.is_same_except — Method
MMI.is_same_except(m1::LaplaceModels, m2::LaplaceModels, exceptions::Symbol...)
If both m1 and m2 are of MLJType, return true if the following conditions all hold, and false otherwise:
- typeof(m1) === typeof(m2)
- propertynames(m1) === propertynames(m2)
- with the exception of properties listed as exceptions or bound to an AbstractRNG, each pair of corresponding property values is either "equal" or both undefined. (If a property appears as a propertyname but not a fieldname, it is deemed as always defined.)
The meaning of "equal" depends on the type of the property value:
- values that are themselves of MLJType are "equal" if they are equal in the sense of is_same_except with no exceptions.
- values that are not of MLJType are "equal" if they are ==.
In the special case of a "deep" property, "equal" has a different meaning; see MLJBase.deep_properties for details.
If m1 or m2 are not MLJType objects, then return ==(m1, m2).
MLJModelInterface.training_losses — Method
MMI.training_losses(model::Union{LaplaceRegressor,LaplaceClassifier}, report)
Retrieve the training loss history from the given report.
Arguments
- model: The model for which the training losses are being retrieved.
- report: An object containing the training report, which includes the loss history.
Returns
- A collection representing the loss history from the training report.
MLJModelInterface.update — Method
MMI.update(m::Union{LaplaceRegressor,LaplaceClassifier}, verbosity, X, y)
Update the Laplace model using the provided new data points.
Arguments
- m: The Laplace (LaplaceRegressor or LaplaceClassifier) model to be fitted.
- verbosity: Verbosity level for logging.
- X: New input features, expected to be in a format compatible with MLJBase.matrix.
- y: New target values.
Returns
- fitresult: A tuple (la, decode) containing the updated fitted Laplace model and y[1], the first element of the categorical y vector.
- cache: A tuple containing a deepcopy of the model, the updated current state of the optimiser and the training loss history.
- report: A NamedTuple containing the complete loss history of the fitting process.
LaplaceRedux.@zb — Macro
Macro for zero-based indexing. Example of usage: (@zb A[0]) = ...
LaplaceRedux.Curvature.EmpiricalFisher — Type
Constructor for curvature approximated by empirical Fisher.
LaplaceRedux.Curvature.GGN — Type
Constructor for curvature approximated by Generalized Gauss-Newton.
LaplaceRedux.Curvature.full_batched — Method
full_batched(curvature::EmpiricalFisher, d::Tuple)
Compute the full empirical Fisher for a batch of inputs-outputs, with the batch dimension at the end.
LaplaceRedux.Curvature.full_batched — Method
full_batched(curvature::GGN, d::Tuple)
Compute the full GGN for a batch of inputs-outputs, with the batch dimension at the end.
LaplaceRedux.Curvature.full_unbatched — Method
full_unbatched(curvature::EmpiricalFisher, d::Tuple)
Compute the full empirical Fisher for a single datapoint.
LaplaceRedux.Curvature.full_unbatched — Method
full_unbatched(curvature::GGN, d::Tuple)
Compute the full GGN for a single input-output datapoint.
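To make the two curvature approximations concrete, here is a sketch for a single datapoint using their textbook definitions: the empirical Fisher is the outer product of the loss gradient, and the GGN is 𝐉ᵀ𝐇𝐉 with 𝐇 the Hessian of the loss with respect to the model outputs. The linear model, the squared-error loss with a factor ½, and all variable names are assumptions chosen so that the Jacobian can be written by hand. The package itself obtains these quantities through automatic differentiation.

```julia
using LinearAlgebra

# Assumed toy model: f(x; θ) = W * x with θ = vec(W), loss ℓ = ½‖f − y‖².
W = [0.5 -1.0 0.2;
     1.5  0.3 0.0]                 # K = 2 outputs, D = 3 inputs
x = [1.0, 2.0, -1.0]
y = [0.0, 1.0]
K = size(W, 1)

J = kron(reshape(x, 1, :), Matrix{Float64}(I, K, K))  # K×(K*D) Jacobian ∂f/∂θ
residual = W * x - y
g = J' * residual                  # gradient of the loss w.r.t. θ
H_out = Matrix{Float64}(I, K, K)   # Hessian of ℓ w.r.t. the outputs (identity here)

ggn = J' * H_out * J               # generalized Gauss-Newton approximation
ef = g * g'                        # empirical Fisher approximation
```

Both results are square matrices in the number of parameters. For one datapoint the empirical Fisher has rank one, while the GGN has rank at most K.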
LaplaceRedux.Curvature.gradients — Method
gradients(curvature::CurvatureInterface, X::AbstractArray, y)
Compute the gradient of the loss function with respect to the parameters: ∇ℓ(f(x;θ),y) where f: ℝᴰ ↦ ℝᴷ. Returns a flat gradient vector for the selected parameter subset.
LaplaceRedux.Curvature.jacobians — Method
jacobians(curvature::CurvatureInterface, X::AbstractArray; batched::Bool=false)
Computes the Jacobian ∇f(x;θ) where f: ℝᴰ ↦ ℝᴷ.
LaplaceRedux.Curvature.jacobians_batched — Method
jacobians_batched(curvature::CurvatureInterface, X::AbstractArray)
Compute Jacobians of the model output w.r.t. model parameters for points in X, with batching.
LaplaceRedux.Curvature.jacobians_unbatched — Method
jacobians_unbatched(curvature::CurvatureInterface, X::AbstractArray)
Compute the Jacobian of the model output w.r.t. model parameters for the point X, without batching. Uses Flux.destructure to obtain a flat parameter vector and computes the Jacobian via Zygote.
LaplaceRedux.Data.toy_data_linear — Function
toy_data_linear(N=100)
Examples
toy_data_linear()
LaplaceRedux.Data.toy_data_multi — Function
toy_data_multi(N=100)
Examples
toy_data_multi()
LaplaceRedux.Data.toy_data_non_linear — Function
toy_data_non_linear(N=100)
Examples
toy_data_non_linear()
LaplaceRedux.Data.toy_data_regression — Function
toy_data_regression(N=25, p=1; noise=0.3, fun::Function=f(x)=sin(2 * π * x))
A helper function to generate synthetic data for regression.