Multi-class problem

Libraries

using Pkg; Pkg.activate("docs")
# Import libraries
using Flux, Plots, TaijaPlotting, Random, Statistics, LaplaceRedux, CategoricalDistributions
theme(:lime)

Data

using LaplaceRedux.Data
seed = 1234
x, y = Data.toy_data_multi(seed=seed)
X = hcat(x...)
y_onehot = Flux.onehotbatch(y, unwrap.(unique(y)))
y_onehot = Flux.unstack(y_onehot',1)
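
To make the encoding step concrete, here is a minimal sketch of one-hot encoding in plain Julia. The helper `onehot_vectors` and the toy labels are hypothetical and only illustrate the idea; the tutorial itself relies on `Flux.onehotbatch`.

```julia
# Hypothetical helper: one Bool vector per observation.
# Entry k is true when the observation's label equals classes[k].
function onehot_vectors(labels, classes)
    return [classes .== label for label in labels]
end

labels = [1, 3, 2, 3]                 # illustrative labels, not the toy data
classes = sort(unique(labels))        # fixes the position of each class
encoded = onehot_vectors(labels, classes)
```

The position of each class in the encoded vector follows the order of the class list that is passed in, so the same list should be reused whenever a specific entry (for example, the third one) is read back later.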

Split into training and test datasets

# Shuffle the data
Random.seed!(seed)
n = length(y)
indices = randperm(n)

# Define the split ratio
split_ratio = 0.8
split_index = Int(floor(split_ratio * n))

# Split the data into training and test sets
train_indices = indices[1:split_index]
test_indices = indices[split_index+1:end]

x_train = x[train_indices]
x_test = x[test_indices]
y_onehot_train = y_onehot[train_indices]
y_onehot_test = y_onehot[test_indices]

y_train = y[train_indices]
y_test = y[test_indices]
# bring into tabular format
X_train = hcat(x_train...) 
X_test = hcat(x_test...) 

data = zip(x_train,y_onehot_train)
#data = zip(x,y_onehot)
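
The split above can also be wrapped in a small reusable function. The sketch below is one possible way to do it: the name `split_indices` and its keyword arguments are hypothetical, and it uses a local random number generator instead of `Random.seed!`, so the shuffled order will generally differ from the one produced above.

```julia
using Random

# Hypothetical helper: shuffle 1:n and cut it into train and test index sets.
function split_indices(n::Integer; ratio::Real=0.8, seed::Integer=1234)
    rng = MersenneTwister(seed)        # local RNG, leaves the global RNG untouched
    shuffled = randperm(rng, n)
    cut = floor(Int, ratio * n)        # number of training observations
    return shuffled[1:cut], shuffled[cut+1:end]
end

train_idx, test_idx = split_indices(100)
```

Because both index sets come from a single permutation, they are disjoint and together cover every observation exactly once.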

MLP

We set up a small multi-layer perceptron with one hidden layer of three sigmoid units:

n_hidden = 3
D = size(X,1)
out_dim = length(unwrap.(unique(y)))
nn = Chain(
    Dense(D, n_hidden, σ),
    Dense(n_hidden, out_dim)
)  
loss(x, y) = Flux.Losses.logitcrossentropy(nn(x), y)
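
The network outputs raw logits rather than probabilities, which is why the loss is a cross-entropy on logits. As a rough illustration of what such a loss computes for a single observation, the following sketch implements it by hand. The function names and the example numbers are hypothetical and this is not Flux's implementation.

```julia
# Numerically stable log-softmax: subtract the maximum before exponentiating.
function log_softmax(logits::AbstractVector{<:Real})
    shifted = logits .- maximum(logits)
    return shifted .- log(sum(exp.(shifted)))
end

# Cross-entropy between a one-hot target and the softmax of the logits.
logit_cross_entropy(logits, target) = -sum(target .* log_softmax(logits))

logits = [2.0, 0.5, -1.0, 0.0]            # illustrative network outputs
target = [true, false, false, false]      # one-hot label for the first class
l = logit_cross_entropy(logits, target)
```

With a one-hot target, the loss reduces to the negative log-probability that the softmax assigns to the true class.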

We train the network with the Adam optimizer for 100 epochs, printing the average loss every 10 epochs:

using Flux.Optimise: update!, Adam
opt = Adam()
epochs = 100
avg_loss(data) = mean(map(d -> loss(d[1],d[2]), data))
show_every = epochs/10

for epoch = 1:epochs
    for d in data
        gs = gradient(Flux.params(nn)) do
            l = loss(d...)
        end
        update!(opt, Flux.params(nn), gs)
    end
    if epoch % show_every == 0
        println("Epoch " * string(epoch))
        @show avg_loss(data)
    end
end

Laplace Approximation

The Laplace approximation can be implemented as follows:

la = Laplace(nn; likelihood=:classification)
fit!(la, data)
optimize_prior!(la; verbosity=1, n_steps=100)

The predictive distribution can then be plotted using either the probit approximation:

_labels = sort(unwrap.(unique(y)))
plt_list = []
for target in _labels
    plt = plot(la, X_test, y_test; target=target, clim=(0,1))
    push!(plt_list, plt)
end
plot(plt_list...)

or the plugin approximation:

_labels = sort(unwrap.(unique(y)))
plt_list = []
for target in _labels
    plt = plot(la, X_test, y_test; target=target, clim=(0,1), link_approx=:plugin)
    push!(plt_list, plt)
end
plot(plt_list...)
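
The difference between the two link approximations can be illustrated with a short sketch. It assumes the commonly used probit-style correction for softmax outputs, in which each mean logit is divided by sqrt(1 + π/8 · variance) before the softmax is applied; whether LaplaceRedux computes exactly this is an assumption, and the function names and numbers below are hypothetical.

```julia
# Softmax with the usual max-shift for numerical stability.
function softmax_vec(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Plugin estimate: push the mean logits straight through the softmax.
plugin_probs(mu) = softmax_vec(mu)

# Probit-style estimate: shrink each logit according to its variance first.
probit_probs(mu, sigma2) = softmax_vec(mu ./ sqrt.(1 .+ (pi / 8) .* sigma2))

mu = [2.0, 0.0, -1.0]          # illustrative mean logits
sigma2 = [4.0, 4.0, 4.0]       # illustrative logit variances
p_plugin = plugin_probs(mu)
p_probit = probit_probs(mu, sigma2)
```

In this example the predicted class is the same under both estimates, but the probit-style probabilities are less extreme because the uncertainty in the logits pulls them toward the uniform distribution.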

Calibration Plots

Calibration plots are defined for binary classification, so they cannot be applied directly to a multi-class task. We can, however, use them to assess the calibration of the predictions for one class against all the others. To do so, we first collect the predicted categorical distributions:

predicted_distributions = predict(la, X_test, ret_distr=true)
1×20 Matrix{Distributions.Categorical{Float64, Vector{Float64}}}:
 Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569184, 0.196066, 0.0296796, 0.717336])  …  Distributions.Categorical{Float64, Vector{Float64}}(support=Base.OneTo(4), p=[0.0569634, 0.195727, 0.0296449, 0.717665])

Then we transform the categorical distributions into Bernoulli distributions by keeping only the probability of the class of interest, for example the third one:

using Distributions
bernoulli_distributions = [Bernoulli(p.p[3]) for p in vec(predicted_distributions)]
20-element Vector{Bernoulli{Float64}}:
 Bernoulli{Float64}(p=0.029679590887034743)
 Bernoulli{Float64}(p=0.6682373773598078)
 Bernoulli{Float64}(p=0.20912995228011141)
 Bernoulli{Float64}(p=0.20913322913224044)
 Bernoulli{Float64}(p=0.02971989045895732)
 Bernoulli{Float64}(p=0.668431087463204)
 Bernoulli{Float64}(p=0.03311710703617972)
 Bernoulli{Float64}(p=0.20912981531862682)
 Bernoulli{Float64}(p=0.11273726979027407)
 Bernoulli{Float64}(p=0.2490744632745955)
 Bernoulli{Float64}(p=0.029886357844211404)
 Bernoulli{Float64}(p=0.02965323602487074)
 Bernoulli{Float64}(p=0.1126799374664026)
 Bernoulli{Float64}(p=0.11278538625980777)
 Bernoulli{Float64}(p=0.6683139127616431)
 Bernoulli{Float64}(p=0.029644435143197145)
 Bernoulli{Float64}(p=0.11324691083703237)
 Bernoulli{Float64}(p=0.6681422555922787)
 Bernoulli{Float64}(p=0.668424345470233)
 Bernoulli{Float64}(p=0.029644891255330787)

Now we can use Calibration_Plot to assess how well calibrated the neural network is:

plt = Calibration_Plot(la,hcat(y_onehot_test...)[3,:],bernoulli_distributions;n_bins = 10);

The plot is peaked around 0.7.

A possible reason is that class 3 is relatively easy for the model to distinguish from the other classes, although the model remains somewhat underconfident in its predictions. Another reason for the peak may be the lack of cases where the predicted probability is lower (e.g., around 0.5), which could indicate that the network has not encountered ambiguous or difficult-to-classify examples of this class. This, in turn, might be either because class 3 has distinct features that the model can easily learn, leading to fewer uncertain predictions, or because the dataset is small.
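
To see what a calibration plot for a single class summarizes, the sketch below bins the predicted probabilities and computes the observed frequency of the class in each bin. It is a generic reliability-diagram computation under simple assumptions (equal-width bins, last bin closed on the right), not the code behind `Calibration_Plot`; the helper `binned_frequencies` and the example data are hypothetical.

```julia
using Statistics

# Hypothetical helper: observed frequency of the positive class per probability bin.
function binned_frequencies(probs::AbstractVector{<:Real}, y_binary::AbstractVector; n_bins::Integer=10)
    edges = range(0, 1; length=n_bins + 1)
    centers = [(edges[i] + edges[i+1]) / 2 for i in 1:n_bins]
    freqs = fill(NaN, n_bins)             # NaN marks bins that received no predictions
    counts = zeros(Int, n_bins)
    for i in 1:n_bins
        lower = probs .>= edges[i]
        # The last bin is closed on the right so that p = 1 is not dropped.
        upper = i < n_bins ? (probs .< edges[i+1]) : (probs .<= edges[i+1])
        in_bin = lower .& upper
        counts[i] = count(in_bin)
        if counts[i] > 0
            freqs[i] = mean(y_binary[in_bin])
        end
    end
    return centers, freqs, counts
end

probs = [0.05, 0.08, 0.65, 0.68, 0.72, 0.21]    # illustrative predicted probabilities
y_binary = [0, 0, 1, 1, 0, 0]                    # 1 when the sample belongs to the class
centers, freqs, counts = binned_frequencies(probs, y_binary; n_bins=10)
```

A well-calibrated model has observed frequencies close to the bin centers. Empty bins, such as those around 0.5 in this toy example, carry no information, which is one way a plot can end up dominated by a single peak.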

We can measure how sharp the neural network is by computing the sharpness score:

sharpness_classification(hcat(y_onehot_test...)[3,:], vec(bernoulli_distributions))


The neural network seems able to correctly classify the majority of samples not belonging to class 3 with relatively high confidence, but remains more uncertain when it encounters examples belonging to class 3.
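
One simple way to summarize sharpness for a single class is to average the confidence assigned to the correct side, separately for samples that belong to the class and for samples that do not. The sketch below follows that idea. Whether it matches the exact definition used by `sharpness_classification` is an assumption, and the helper name and numbers are hypothetical.

```julia
using Statistics

# Hypothetical helper: mean confidence on the correct side, per true class.
function sharpness_sketch(y_binary::AbstractVector, probs::AbstractVector{<:Real})
    positives = y_binary .== 1
    mean_positive = mean(probs[positives])           # average p for true members of the class
    mean_negative = mean(1 .- probs[.!positives])    # average 1 - p for all other samples
    return mean_positive, mean_negative
end

y_binary = [1, 1, 0, 0, 0]                # illustrative one-vs-rest labels
probs = [0.7, 0.6, 0.1, 0.2, 0.3]         # illustrative predicted probabilities
s_pos, s_neg = sharpness_sketch(y_binary, probs)
```

Values close to 1 on both sides indicate confident predictions; a lower value for the members of the class than for the rest corresponds to the pattern described above.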