Bayesian MLP

Libraries

# Activate the project environment located in the `docs` folder.
using Pkg; Pkg.activate("docs")
# Import libraries
using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra
# Use the `:lime` plotting theme for all figures below.
theme(:lime)

Data

This time we use a synthetic dataset containing samples that are not linearly separable:

# Set the global RNG seed so that the random train/test split below is reproducible.
seed = 1234
Random.seed!(seed)
# Generate 400 points of the non-linear toy data set (the generator gets the same seed).
xs, ys = LaplaceRedux.Data.toy_data_non_linear(400; seed = seed)

# Shuffle the data: a random permutation of the sample indices.
n = length(ys)
indices = randperm(n)

# Define the split ratio (fraction of samples used for training).
split_ratio = 0.8
split_index = floor(Int, split_ratio * n)

# Split the data into training and test sets.
train_indices = indices[1:split_index]
test_indices = indices[(split_index + 1):end]

xs_train = xs[train_indices]
xs_test = xs[test_indices]
ys_train = ys[train_indices]
ys_test = ys[test_indices]

# Bring into tabular format: one column per sample, so `size(X, 1)` is the number of
# input features. `reduce(hcat, ...)` avoids splatting hundreds of vectors as varargs.
X_train = reduce(hcat, xs_train)
X_test = reduce(hcat, xs_test)

# Lazy iterator of (x, y) training pairs; `zip` can be re-iterated once per epoch.
data = zip(xs_train, ys_train)

Model

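For the classification task we build a neural network composed of a single hidden layer. Note that the training loop below applies no explicit weight decay; that role is played later by the prior precision λ of the Laplace approximation.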

n_hidden = 10          # width of the hidden layer
D = size(X_train, 1)   # number of input features (rows of X_train)

# One hidden layer of `n_hidden` sigmoid units, then a single linear output unit
# that produces a logit.
nn = Chain(Dense(D, n_hidden, σ), Dense(n_hidden, 1))

# Binary cross-entropy evaluated directly on the network's logits
# (per the Flux docs, equivalent to binary cross-entropy on σ(logits), but more stable).
function loss(x, y)
    logits = nn(x)
    return Flux.Losses.logitbinarycrossentropy(logits, y)
end

The model is trained for a fixed number of epochs; the average training loss is printed periodically so one can check that it has stagnated.

using Flux.Optimise: update!, Adam
opt = Adam(1e-3)
epochs = 100
# Mean of `loss` over all (x, y) pairs in `data`.
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
# Report progress about 10 times. Integer division keeps this an Int (no float `%`),
# and `max(1, ...)` prevents a zero interval (a `DivideError`) when epochs < 10.
show_every = max(1, epochs ÷ 10)

# Trainable parameters of `nn`. They are collected once instead of twice per sample;
# `update!` mutates these arrays in place, so the collection stays valid.
ps = Flux.params(nn)

for epoch in 1:epochs
    # One Adam step per training sample (batch size 1).
    for (x, y) in data
        gs = gradient(ps) do
            loss(x, y)
        end
        update!(opt, ps, gs)
    end
    if epoch % show_every == 0
        println("Epoch $epoch")
        @show avg_loss(data)
    end
end

Laplace Approximation

The Laplace approximation is fitted to the trained network as follows. An untuned copy is kept for plotting before the prior precision is optimized:

# Laplace approximation over all weights of the trained network `nn`,
# using a classification likelihood, fitted on the training pairs in `data`.
la = Laplace(nn; likelihood=:classification, subset_of_weights=:all)
fit!(la, data)
# Snapshot taken after fitting but before prior tuning. `optimize_prior!` below
# modifies `la` in place, so a deep copy is required to keep the untuned version.
la_untuned = deepcopy(la)   # saving for plotting
# Tune the prior precision for 500 steps (verbosity=1 presumably reports progress).
optimize_prior!(la; verbosity=1, n_steps=500)

The plot below shows the resulting posterior predictive surface for the plugin estimator (left), the Laplace approximation with the initial, untuned prior (middle), and the Laplace approximation with the tuned prior (right).

# Plot the posterior predictive surface as contour plots for three estimators:
# the plugin estimator, the untuned Laplace approximation and the tuned one.
zoom = 0
# Prior precision λ read from the diagonal of P₀, rounded for the titles.
# NOTE(review): assumes P₀ is isotropic (a single unique diagonal value) — confirm.
λ_untuned = round(first(unique(diag(la_untuned.prior.P₀))); digits=2)
λ_tuned = round(first(unique(diag(la.prior.P₀))); digits=2)
# All three panels get the same `zoom` so that their axes are comparable.
p_plugin = plot(
    la, X_train, ys_train;
    title="Plugin", link_approx=:plugin, clim=(0, 1), zoom=zoom,
)
p_untuned = plot(
    la_untuned, X_train, ys_train;
    title="LA - raw (λ=$(λ_untuned))", clim=(0, 1), zoom=zoom,
)
p_laplace = plot(
    la, X_train, ys_train;
    title="LA - tuned (λ=$(λ_tuned))", clim=(0, 1), zoom=zoom,
)
plot(p_plugin, p_untuned, p_laplace; layout=(1, 3), size=(1700, 400))

Zooming out, we can see that the plugin estimator produces high-confidence predictions even in regions without any training samples, whereas the Laplace approximation is much more conservative in these regions.

# Zoom out (negative zoom) to inspect predictions far away from the training data.
zoom = -50
# Prior precision λ read from the diagonal of P₀, rounded for the titles.
# NOTE(review): assumes P₀ is isotropic (a single unique diagonal value) — confirm.
λ_untuned = round(first(unique(diag(la_untuned.prior.P₀))); digits=2)
λ_tuned = round(first(unique(diag(la.prior.P₀))); digits=2)
# The plugin panel must also receive `zoom`; otherwise it is not zoomed out
# and cannot show its behaviour away from the data.
p_plugin = plot(
    la, X_train, ys_train;
    title="Plugin", link_approx=:plugin, clim=(0, 1), zoom=zoom,
)
p_untuned = plot(
    la_untuned, X_train, ys_train;
    title="LA - raw (λ=$(λ_untuned))", clim=(0, 1), zoom=zoom,
)
p_laplace = plot(
    la, X_train, ys_train;
    title="LA - tuned (λ=$(λ_tuned))", clim=(0, 1), zoom=zoom,
)
plot(p_plugin, p_untuned, p_laplace; layout=(1, 3), size=(1700, 400))

We now show the calibration plot on the test set to assess the level of average calibration reached by the neural network.

# Predictive distributions (rather than point probabilities) for the test inputs.
predicted_distributions = predict(la, X_test; ret_distr=true)
# Calibration plot of the flattened test-set predictions, using 10 bins.
Calibration_Plot(la, ys_test, vec(predicted_distributions); n_bins=10)

and compute the sharpness score:

# Sharpness of the flattened test-set predictive distributions (yields a pair of scores).
sharpness_classification(ys_test, vec(predicted_distributions))
(0.9277189055456709, 0.9196132560599691)