Libraries
Import the libraries required to run this example
# Activate the project environment in the `docs` directory; this must happen
# before the `using` line below so the packages are loaded from that environment.
using Pkg; Pkg.activate("docs")
# Import libraries
using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux
# Select the `:wong` plotting theme for the figures produced below.
theme(:wong)

Data
We first generate some synthetic data:
# Synthetic 1-D regression data from LaplaceRedux's toy-data helpers.
using LaplaceRedux.Data
n = 3000        # number of observations
σtrue = 0.30    # true observational noise, passed as the `noise` keyword
x, y = Data.toy_data_regression(n; noise = σtrue, seed = 1234)
# One single-element input vector per observation ...
xs = map(xi -> [xi], x)
# ... and all inputs laid out as a row matrix (1×n when `x` is a vector).
X = permutedims(x)

and split them into a training set and a test set:
# Shuffle the data
Random.seed!(1234) # Set a seed for reproducibility
shuffle_indices = shuffle(1:n)
# Define split ratios. The test set is whatever remains after the training
# portion, so the two ratios must add up to one; check that instead of
# silently ignoring `test_ratio`.
train_ratio = 0.8
test_ratio = 0.2
train_ratio + test_ratio ≈ 1 || throw(ArgumentError(
    "train_ratio + test_ratio must equal 1, got $(train_ratio + test_ratio)"))
# Last index (within the shuffled order) that belongs to the training set
train_end = floor(Int, train_ratio * n)
# Split the shuffled indices
train_indices = shuffle_indices[1:train_end]
test_indices = shuffle_indices[(train_end + 1):end]
# Create the splits
x_train, y_train = x[train_indices], y[train_indices]
x_test, y_test = x[test_indices], y[test_indices]
# Per-sample format: each scalar input wrapped in a 1-element vector.
# A distinct loop name avoids shadowing the global data vector `x`.
xs_train = [[xi] for xi in x_train]
xs_test = [[xi] for xi in x_test]
# Row-matrix format used for plotting and batch prediction.
X_train = permutedims(x_train)
X_test = permutedims(x_test)

MLP
We set up a model and a mean-squared-error loss (the loss itself contains no explicit weight-regularization term):
# Training examples as (1-element input vector, target) pairs.
train_data = zip(xs_train, y_train)
n_hidden = 50
# Input dimension, taken from the matrix the model is actually trained and
# plotted on (the training split) rather than from the full dataset.
D = size(X_train, 1)
# One hidden `tanh` layer followed by a linear output layer.
nn = Chain(
    Dense(D, n_hidden, tanh),
    Dense(n_hidden, 1)
)
# Plain mean-squared error between the network output and the target.
loss(x, y) = Flux.Losses.mse(nn(x), y)

We train the model:
using Flux.Optimise: update!, Adam
opt = Adam(1e-3)
epochs = 1000

# Mean of `loss` over every (input, target) pair in `data`.
avg_loss(data) = mean(map(d -> loss(d...), data))

# Report progress ten times over the run. Integer division keeps the interval
# an Int; `max(1, ...)` prevents a zero interval (and `epoch % 0` throwing)
# when `epochs < 10`.
show_every = max(1, epochs ÷ 10)
# Implicit-parameter training: collect the model's trainable arrays once
# instead of rebuilding the collection twice per sample.
ps = Flux.params(nn)
for epoch in 1:epochs
    for (xi, yi) in train_data
        gs = gradient(() -> loss(xi, yi), ps)
        update!(opt, ps, gs)
    end
    if epoch % show_every == 0
        println("Epoch $epoch")
        @show avg_loss(train_data)
    end
end

Laplace Approximation
Laplace approximation can be implemented as follows:
# Build a Laplace approximation with the regression likelihood over the
# selected subset of weights (`:all` here), fit it on the training pairs,
# and plot it against the training data.
subset_w = :all
la = Laplace(
    nn;
    likelihood = :regression,
    subset_of_weights = subset_w
)
fit!(la, train_data)
plot(la, X_train, y_train; zoom = -5, size = (400, 400))

Next we optimize the prior precision $P_0$ and the observational noise $\sigma$ using Empirical Bayes:
# Tune the prior precision and observational noise (Empirical Bayes); with
# verbosity = 1 the intermediate loss terms shown below are printed.
optimize_prior!(la; verbosity = 1)
# Re-plot the predictive distribution with the tuned hyperparameters.
plot(la, X_train, y_train; zoom = -5, size = (400, 400))

loss(exp.(logP₀), exp.(logσ)) = 668.3714946472106
Log likelihood: -618.5175117610522
Log det ratio: 68.76532606873238
Scatter: 30.942639703584522
loss(exp.(logP₀), exp.(logσ)) = 719.2536119935747
Log likelihood: -673.0996963447847
Log det ratio: 76.53255037599948
Scatter: 15.775280921580569
loss(exp.(logP₀), exp.(logσ)) = 574.605864472924
Log likelihood: -528.694286608232
Log det ratio: 80.73114330857285
Scatter: 11.092012420811196
loss(exp.(logP₀), exp.(logσ)) = 568.4433850825203
Log likelihood: -522.4407550111031
Log det ratio: 82.10089958560243
Scatter: 9.90436055723207
loss(exp.(logP₀), exp.(logσ)) = 566.9485255672008
Log likelihood: -520.9682443835385
Log det ratio: 81.84516297272847
Scatter: 10.11539939459612
loss(exp.(logP₀), exp.(logσ)) = 559.9852101992792
Log likelihood: -514.0625630685765
Log det ratio: 80.97813304453496
Scatter: 10.867161216870441
loss(exp.(logP₀), exp.(logσ)) = 559.1404593114019
Log likelihood: -513.2449017869876
Log det ratio: 80.16026747795866
Scatter: 11.630847570869795
loss(exp.(logP₀), exp.(logσ)) = 559.3201392562346
Log likelihood: -513.4273312363501
Log det ratio: 79.68892769076004
Scatter: 12.096688349008877
loss(exp.(logP₀), exp.(logσ)) = 559.2111983983311
Log likelihood: -513.3174948065804
Log det ratio: 79.56631681347287
Scatter: 12.2210903700287
loss(exp.(logP₀), exp.(logσ)) = 559.1107459310829
Log likelihood: -513.2176579845662
Log det ratio: 79.63946732368183
Scatter: 12.146708569351494

Calibration Plot
Once the prior precision has been optimized, we can evaluate the quality of the resulting predictive distribution with a calibration plot on the test dataset (`y_test`, `X_test`).
First, we apply the trained network to the test inputs `X_test` and collect the neural network's predicted distributions:
# Predictive distribution for every test input; `ret_distr = true` requests
# distribution objects rather than point summaries.
predicted_distributions = predict(la, X_test; ret_distr = true)

600×1 Matrix{Distributions.Normal{Float64}}:
Distributions.Normal{Float64}(μ=-0.1137533187866211, σ=0.07161056521032018)
Distributions.Normal{Float64}(μ=0.7063850164413452, σ=0.050697938829269665)
Distributions.Normal{Float64}(μ=-0.2211049497127533, σ=0.06876939416479119)
Distributions.Normal{Float64}(μ=0.720299243927002, σ=0.08665125572287981)
Distributions.Normal{Float64}(μ=-0.8338974714279175, σ=0.06464012115237727)
Distributions.Normal{Float64}(μ=0.9910320043563843, σ=0.07452060172164382)
Distributions.Normal{Float64}(μ=0.1507074236869812, σ=0.07316299850461126)
Distributions.Normal{Float64}(μ=0.20875799655914307, σ=0.05507748397231652)
Distributions.Normal{Float64}(μ=0.973572850227356, σ=0.07899004963915071)
Distributions.Normal{Float64}(μ=0.9497100114822388, σ=0.07750126389821968)
Distributions.Normal{Float64}(μ=0.22462180256843567, σ=0.07103664786246695)
Distributions.Normal{Float64}(μ=-0.7654240131378174, σ=0.05501397704409917)
Distributions.Normal{Float64}(μ=1.0029183626174927, σ=0.07619466916431794)
⋮
Distributions.Normal{Float64}(μ=0.7475956678390503, σ=0.049875919157527815)
Distributions.Normal{Float64}(μ=0.019430622458457947, σ=0.07445076746045155)
Distributions.Normal{Float64}(μ=-0.9451781511306763, σ=0.05929712369810892)
Distributions.Normal{Float64}(μ=-0.9813591241836548, σ=0.05844012710417755)
Distributions.Normal{Float64}(μ=-0.6470385789871216, σ=0.055754609087554294)
Distributions.Normal{Float64}(μ=-0.34288135170936584, σ=0.05533523375842789)
Distributions.Normal{Float64}(μ=0.9912381172180176, σ=0.07872473667398772)
Distributions.Normal{Float64}(μ=-0.824547290802002, σ=0.05499258101374759)
Distributions.Normal{Float64}(μ=-0.3306621015071869, σ=0.06745251908756716)
Distributions.Normal{Float64}(μ=0.3742436170578003, σ=0.10588913330223387)
Distributions.Normal{Float64}(μ=0.0875578224658966, σ=0.07436153828228255)
Distributions.Normal{Float64}(μ=-0.34871187806129456, σ=0.06742745343084512)

Then we can draw the calibration plot of our neural model:
# `vec` flattens the 600×1 matrix of predictive distributions into a vector.
Calibration_Plot(la, y_test, vec(predicted_distributions); n_bins = 20)

and compute the sharpness of the predictive distribution:
# Sharpness of the (flattened) predictive distributions on the test set.
sharpness_regression(vec(predicted_distributions))

0.005058067743863281