Bayesian Logistic Regression

Libraries

using Pkg; Pkg.activate("docs")
# Import libraries
using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, LinearAlgebra
theme(:lime)

Data

We will use synthetic data with linearly separable samples:

# Set the seed for reproducibility
seed = 1234
Random.seed!(seed)
# Generate 100 points
xs, ys = LaplaceRedux.Data.toy_data_linear(100; seed=seed)
X = hcat(xs...) # bring into tabular format

Next, we split the data into a training set and a test set:

# Shuffle the data
n = length(ys)
indices = randperm(n)

# Define the split ratio
split_ratio = 0.8
split_index = floor(Int, split_ratio * n)

# Split the data into training and test sets
train_indices = indices[1:split_index]
test_indices = indices[split_index+1:end]

xs_train = xs[train_indices]
xs_test = xs[test_indices]
ys_train = ys[train_indices]
ys_test = ys[test_indices]
# bring into tabular format
X_train = hcat(xs_train...) 
X_test = hcat(xs_test...) 

data = zip(xs_train,ys_train)
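As a quick sanity check on the split logic, the sketch below reproduces the index shuffling on its own, using only the Random standard library. The value n = 100 mirrors the call above; the helper name `split_indices` is hypothetical and not part of LaplaceRedux.

```julia
using Random

# Hypothetical helper: shuffle 1:n and cut the permutation at `ratio`.
function split_indices(n::Int, ratio::Float64; rng=Random.default_rng())
    idx = randperm(rng, n)
    cut = floor(Int, ratio * n)
    return idx[1:cut], idx[cut+1:end]
end

train_idx, test_idx = split_indices(100, 0.8; rng=MersenneTwister(1234))

# The two index sets should be disjoint and together cover 1:n.
@assert isempty(intersect(train_idx, test_idx))
@assert sort(vcat(train_idx, test_idx)) == collect(1:100)
println((length(train_idx), length(test_idx)))
```

Checking that the two sets are disjoint guards against test points leaking into training, which would make the calibration results further down look better than they are.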

Model

Logistic regression with weight decay can be implemented in Flux.jl as a single dense (linear) layer trained with a logit binary cross-entropy loss plus an L2 penalty on the parameters:

nn = Chain(Dense(2,1))
λ = 0.5
sqnorm(x) = sum(abs2, x)
weight_regularization(λ=λ) = 1/2 * λ^2 * sum(sqnorm, Flux.params(nn))
loss(x, y) = Flux.Losses.logitbinarycrossentropy(nn(x), y) + weight_regularization()
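To make the objective concrete, the sketch below evaluates the same kind of penalized loss by hand for a single sample, without Flux. The weights, bias, and data point are made-up values chosen for illustration, and `logit_bce` is a hypothetical helper that writes the binary cross-entropy directly in terms of the logit. Since `Flux.params(nn)` contains both the weight matrix and the bias, the penalty here covers both as well.

```julia
# Made-up parameters and sample, for illustration only.
w = [0.5, -1.0]      # weights of the single dense layer
b = 0.1              # bias
x = [1.0, 2.0]       # one input point
y = 1.0              # its binary label
λ = 0.5

# Binary cross-entropy expressed on the logit z:
# -y*log(σ(z)) - (1 - y)*log(1 - σ(z)) = (1 - y)*z + log(1 + exp(-z))
logit_bce(z, y) = (1 - y) * z + log1p(exp(-z))

sqnorm(v) = sum(abs2, v)

z = sum(w .* x) + b                            # logit: w'x + b
penalty = 1/2 * λ^2 * (sqnorm(w) + abs2(b))    # same form as weight_regularization above
total = logit_bce(z, y) + penalty

# Cross-check against the textbook definition via the sigmoid.
σ(t) = 1 / (1 + exp(-t))
@assert isapprox(logit_bce(z, y), -log(σ(z)); atol=1e-12)
println(total)
```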

The code below trains the model. After about 50 training epochs, the training loss stagnates.

using Flux.Optimise: update!, Adam
opt = Adam()
epochs = 50
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs ÷ 10

for epoch = 1:epochs
  for d in data
    gs = gradient(Flux.params(nn)) do
      l = loss(d...)
    end
    update!(opt, Flux.params(nn), gs)
  end
  if epoch % show_every == 0
    println("Epoch " * string(epoch))
    @show avg_loss(data)
  end
end

Laplace approximation

Laplace approximation for the posterior predictive can be implemented as follows:

la = Laplace(nn; likelihood=:classification, λ=λ, subset_of_weights=:last_layer)
fit!(la, data)
la_untuned = deepcopy(la)   # saving for plotting
optimize_prior!(la; verbosity=1, n_steps=500)
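For intuition, the following sketch works through a Laplace approximation for a plain logistic regression by hand: the posterior is approximated by a Gaussian centered at the MAP estimate, with covariance equal to the inverse Hessian of the negative log posterior, and the predictive is obtained with the probit approximation. This is a textbook illustration and not LaplaceRedux's implementation. The data, the MAP estimate `θ_map`, the prior precision, and the function names are all made up, and the bias term is omitted for brevity.

```julia
using LinearAlgebra

σ(t) = 1 / (1 + exp(-t))

# Made-up inputs (one column per sample) and a made-up MAP estimate.
X = [1.0 2.0 -1.0 -2.0;
     0.5 1.5 -0.5 -1.0]
θ_map = [0.8, 0.4]       # pretend these weights came out of training
prior_precision = 0.5    # isotropic Gaussian prior, P₀ = prior_precision * I

# Hessian of the negative log posterior at θ:
# H = Σᵢ pᵢ(1 - pᵢ) xᵢxᵢᵀ + P₀
function posterior_precision(X, θ, prior_precision)
    H = prior_precision * Matrix(1.0I, length(θ), length(θ))
    for x in eachcol(X)
        p = σ(dot(θ, x))
        H += p * (1 - p) * (x * x')
    end
    return H
end

H = posterior_precision(X, θ_map, prior_precision)
Σ = inv(Symmetric(H))    # Laplace posterior: N(θ_map, Σ)

# Probit approximation of the posterior predictive for a new point.
function predict_probit(x, θ, Σ)
    μ = dot(θ, x)        # mean of the logit
    v = dot(x, Σ * x)    # variance of the logit
    return σ(μ / sqrt(1 + π / 8 * v))
end

x_new = [3.0, 2.0]
p_plugin = σ(dot(θ_map, x_new))
p_laplace = predict_probit(x_new, θ_map, Σ)
println((p_plugin, p_laplace))
```

Because the logit variance is positive, the probit-approximated prediction is pulled towards 0.5 relative to the plugin estimate, and the pull grows for points far from the training data.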

The plot below shows the resulting posterior predictive surface for the plugin estimator (left), the Laplace approximation with the initial prior (center), and the Laplace approximation with the optimized prior (right).

zoom = 0
p_plugin = plot(la, X, ys; title="Plugin", link_approx=:plugin, clim=(0,1))
p_untuned = plot(la_untuned, X, ys; title="LA - raw (λ=$(unique(diag(la_untuned.prior.P₀))[1]))", clim=(0,1), zoom=zoom)
p_laplace = plot(la, X, ys; title="LA - tuned (λ=$(round(unique(diag(la.prior.P₀))[1],digits=2)))", clim=(0,1), zoom=zoom)
plot(p_plugin, p_untuned, p_laplace, layout=(1,3), size=(1700,400))

Now we can assess how well calibrated the neural network is. First, we collect the predicted distributions over the test set:

predicted_distributions = predict(la, X_test; ret_distr=true)
1×20 Matrix{Distributions.Bernoulli{Float64}}:
 Distributions.Bernoulli{Float64}(p=0.13122)  …  Distributions.Bernoulli{Float64}(p=0.109559)

Then we draw the calibration plot:

Calibration_Plot(la, ys_test, vec(predicted_distributions); n_bins=10)

As the plot shows, the neural network is extremely accurate but does not appear to be well calibrated. This is, however, a consequence of that accuracy: the network produces almost no predictions with high uncertainty, so few predictions fall in the intermediate probability bins. We can see this by looking at the sharpness for the two classes, both of which are close to 1 (about 0.91 and 0.89), indicating how much trust the neural network places in its predictions.

sharpness_classification(ys_test,vec(predicted_distributions))
(0.9131870336577175, 0.8865055827351365)
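To make the two diagnostics concrete, the sketch below computes a binned reliability table and a per-class sharpness score from raw probabilities. It is a generic illustration on made-up predictions. It assumes that sharpness means the average probability assigned to the correct class, computed separately for each class, which may differ from what `Calibration_Plot` and `sharpness_classification` do internally. The helper names `reliability_bins` and `sharpness_per_class` are hypothetical.

```julia
using Statistics

# Made-up predicted probabilities of class 1 and true labels.
probs  = [0.05, 0.10, 0.12, 0.90, 0.93, 0.97, 0.08, 0.88]
labels = [0, 0, 0, 1, 1, 1, 0, 1]

# Hypothetical helper: for each non-empty bin, report the mean predicted
# probability, the observed fraction of positives, and the bin size.
function reliability_bins(probs, labels; n_bins=10)
    # Map each probability to a bin index in 1:n_bins (p = 1 goes to the last bin).
    bin_ids = [min(floor(Int, p * n_bins) + 1, n_bins) for p in probs]
    rows = []
    for b in 1:n_bins
        in_bin = findall(==(b), bin_ids)
        isempty(in_bin) && continue
        push!(rows, (bin=b,
                     mean_pred=mean(probs[in_bin]),
                     frac_pos=mean(labels[in_bin]),
                     count=length(in_bin)))
    end
    return rows
end

# Hypothetical helper: average probability assigned to the true class,
# separately for class 1 and class 0 (an assumed definition of sharpness).
function sharpness_per_class(probs, labels)
    mean_one  = mean(probs[labels .== 1])
    mean_zero = mean(1 .- probs[labels .== 0])
    return mean_one, mean_zero
end

rows = reliability_bins(probs, labels)
foreach(println, rows)
println(sharpness_per_class(probs, labels))
```

In this made-up example every prediction lands in a bin near 0 or near 1, so the middle of the reliability curve is empty. That is the same pattern the text describes for a very accurate, very confident classifier.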