Performance Benchmarks
In the previous tutorial, we saw how counterfactual explanations can be evaluated. An important follow-up task is to compare the performance of different counterfactual generators. Researchers can use benchmarks to test new ideas they want to implement. Practitioners can use benchmarks to find the right counterfactual generator for their specific use case. In this tutorial, we will see how to run benchmarks for counterfactual generators.
Post Hoc Benchmarking
We begin by continuing the discussion from the previous tutorial: suppose you have generated multiple counterfactual explanations for multiple individuals, like below.
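The snippet below assumes a counterfactual_data object, a fitted model M, a generator and factual and target labels, as in the previous tutorials. A minimal sketch of such a setup (the data loader, model type and label choices here are assumptions, not part of the original example) could look like this:
using CounterfactualExplanations
using CounterfactualExplanations.Evaluation: benchmark, evaluate

# Assumed setup, consistent with the other examples in this tutorial:
counterfactual_data = CounterfactualData(load_moons()...)  # any CounterfactualData works
M = fit_model(counterfactual_data, :MLP)                   # model to be explained
generator = GenericGenerator()                             # counterfactual generator
ys = predict_label(M, counterfactual_data)
factual = ys[1]                                 # treat the first predicted label as the factual class
target = first(setdiff(unique(ys), [factual]))  # any other label serves as the target
With this in place, counterfactual explanations can be generated for multiple individuals: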
# Factual and target:
n_individuals = 5
ids = rand(findall(predict_label(M, counterfactual_data) .== factual), n_individuals)
xs = select_factual(counterfactual_data, ids)
ces = generate_counterfactual(xs, target, counterfactual_data, M, generator; num_counterfactuals=5)
You may be interested in comparing the outcomes across individuals. To benchmark the various counterfactual explanations using default evaluation measures, you can simply proceed as follows:
bmk = benchmark(ces)
Under the hood, benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}) uses CounterfactualExplanations.Evaluation.evaluate(ce::CounterfactualExplanation) to generate a Benchmark object, which contains the evaluation in its most granular form as a DataFrame.
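The same granular evaluation can also be produced for a single explanation by calling evaluate directly (a minimal sketch, assuming the ces generated above):
# Evaluate a single counterfactual explanation with the default measures:
evaluate(ces[1])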
Working with Benchmarks
For convenience, the DataFrame containing the evaluation can be returned by simply calling the Benchmark object. By default, the evaluation measures are aggregated across id (in line with the default behaviour of evaluate).
bmk()
15×7 DataFrame
 Row │ sample                                variable    value    generator    ⋯
     │ Base.UUID                             String      Float64  Symbol       ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 239104d0-f59f-11ee-3d0c-d1db071927ff  distance    3.17243  GradientBase ⋯
   2 │ 239104d0-f59f-11ee-3d0c-d1db071927ff  redundancy  0.0      GradientBase
   3 │ 239104d0-f59f-11ee-3d0c-d1db071927ff  validity    1.0      GradientBase
   4 │ 2398b3e2-f59f-11ee-3323-13d53fb7e75b  distance    3.07148  GradientBase
   5 │ 2398b3e2-f59f-11ee-3323-13d53fb7e75b  redundancy  0.0      GradientBase ⋯
   6 │ 2398b3e2-f59f-11ee-3323-13d53fb7e75b  validity    1.0      GradientBase
   7 │ 2398b916-f59f-11ee-3f13-bd00858a39af  distance    3.62159  GradientBase
   8 │ 2398b916-f59f-11ee-3f13-bd00858a39af  redundancy  0.0      GradientBase
   9 │ 2398b916-f59f-11ee-3f13-bd00858a39af  validity    1.0      GradientBase ⋯
  10 │ 2398bce8-f59f-11ee-37c1-ef7c6de27b6b  distance    2.62783  GradientBase
  11 │ 2398bce8-f59f-11ee-37c1-ef7c6de27b6b  redundancy  0.0      GradientBase
  12 │ 2398bce8-f59f-11ee-37c1-ef7c6de27b6b  validity    1.0      GradientBase
  13 │ 2398c08a-f59f-11ee-175b-81c155750752  distance    2.91985  GradientBase ⋯
  14 │ 2398c08a-f59f-11ee-175b-81c155750752  redundancy  0.0      GradientBase
  15 │ 2398c08a-f59f-11ee-175b-81c155750752  validity    1.0      GradientBase
                                                                4 columns omitted
To retrieve the granular dataset, simply do:
bmk(agg=nothing)
75×8 DataFrame
 Row │ sample                                num_counterfactual  variable    v ⋯
     │ Base.UUID                             Int64               String      F ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   1  distance    3 ⋯
   2 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   2  distance    3
   3 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   3  distance    3
   4 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   4  distance    3
   5 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   5  distance    3 ⋯
   6 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   1  redundancy  0
   7 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   2  redundancy  0
   8 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   3  redundancy  0
   9 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   4  redundancy  0 ⋯
  10 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   5  redundancy  0
  11 │ 239104d0-f59f-11ee-3d0c-d1db071927ff                   1  validity    1
  ⋮  │                  ⋮                            ⋮               ⋮        ⋱
  66 │ 2398c08a-f59f-11ee-175b-81c155750752                   1  redundancy  0
  67 │ 2398c08a-f59f-11ee-175b-81c155750752                   2  redundancy  0 ⋯
  68 │ 2398c08a-f59f-11ee-175b-81c155750752                   3  redundancy  0
  69 │ 2398c08a-f59f-11ee-175b-81c155750752                   4  redundancy  0
  70 │ 2398c08a-f59f-11ee-175b-81c155750752                   5  redundancy  0
  71 │ 2398c08a-f59f-11ee-175b-81c155750752                   1  validity    1 ⋯
  72 │ 2398c08a-f59f-11ee-175b-81c155750752                   2  validity    1
  73 │ 2398c08a-f59f-11ee-175b-81c155750752                   3  validity    1
  74 │ 2398c08a-f59f-11ee-175b-81c155750752                   4  validity    1
  75 │ 2398c08a-f59f-11ee-175b-81c155750752                   5  validity    1 ⋯
                                                  5 columns and 54 rows omitted
Since benchmarks return a DataFrame object on call, post-processing is straightforward. For example, we could use Tidier.jl:
using Tidier
@chain bmk() begin
@filter(variable == "distance")
@select(sample, variable, value)
end
5×3 DataFrame
 Row │ sample                                variable  value
     │ Base.UUID                             String    Float64
─────┼──────────────────────────────────────────────────────────
   1 │ 239104d0-f59f-11ee-3d0c-d1db071927ff  distance  3.17243
   2 │ 2398b3e2-f59f-11ee-3323-13d53fb7e75b  distance  3.07148
   3 │ 2398b916-f59f-11ee-3f13-bd00858a39af  distance  3.62159
   4 │ 2398bce8-f59f-11ee-37c1-ef7c6de27b6b  distance  2.62783
   5 │ 2398c08a-f59f-11ee-175b-81c155750752  distance  2.91985
Metadata for Counterfactual Explanations
Benchmarks always report metadata for each counterfactual explanation, which is automatically inferred by default. The default metadata concerns the explained model and the employed generator. In the current example, we used the same model and generator for each individual:
@chain bmk() begin
@group_by(sample)
@select(sample, model, generator)
@summarize(model=first(model),generator=first(generator))
@ungroup
end
5×3 DataFrame
 Row │ sample                                model                              ⋯
     │ Base.UUID                             Symbol                             ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 239104d0-f59f-11ee-3d0c-d1db071927ff  FluxModel(Chain(Dense(2 => 2)), …  ⋯
   2 │ 2398b3e2-f59f-11ee-3323-13d53fb7e75b  FluxModel(Chain(Dense(2 => 2)), …
   3 │ 2398b916-f59f-11ee-3f13-bd00858a39af  FluxModel(Chain(Dense(2 => 2)), …
   4 │ 2398bce8-f59f-11ee-37c1-ef7c6de27b6b  FluxModel(Chain(Dense(2 => 2)), …
   5 │ 2398c08a-f59f-11ee-175b-81c155750752  FluxModel(Chain(Dense(2 => 2)), …  ⋯
                                                                 1 column omitted
Metadata can also be provided as an optional keyword argument:
meta_data = Dict(
:generator => "Generic",
:model => "MLP",
)
meta_data = [meta_data for i in 1:length(ces)]
bmk = benchmark(ces; meta_data=meta_data)
@chain bmk() begin
@group_by(sample)
@select(sample, model, generator)
@summarize(model=first(model),generator=first(generator))
@ungroup
end
5×3 DataFrame
 Row │ sample                                model   generator
     │ Base.UUID                             String  String
─────┼──────────────────────────────────────────────────────────
   1 │ 27fae496-f59f-11ee-2c30-f35d1025a6d4  MLP     Generic
   2 │ 27fdcc6a-f59f-11ee-030b-152c9794c5f1  MLP     Generic
   3 │ 27fdd04a-f59f-11ee-2010-e1732ff5d8d2  MLP     Generic
   4 │ 27fdd340-f59f-11ee-1d20-050a69dcacef  MLP     Generic
   5 │ 27fdd5fc-f59f-11ee-02e8-d198e436abb3  MLP     Generic
Ad Hoc Benchmarking
So far we have assumed the following workflow:
- Fit some machine learning model.
- Generate counterfactual explanations for some individual(s) (generate_counterfactual).
- Evaluate and benchmark them (benchmark(ces::Vector{CounterfactualExplanation})).
In many cases, it may be preferable to combine these steps. To this end, we have added support for two scenarios of Ad Hoc Benchmarking.
Pre-trained Models
In the first scenario, it is assumed that the machine learning models have been pre-trained and so the workflow can be summarized as follows:
- Fit some machine learning model(s).
- Generate counterfactual explanations and benchmark them.
We suspect that this is the most common workflow for practitioners who are interested in benchmarking counterfactual explanations for pre-trained machine learning models. Let's go through this workflow using a simple example. We first train some models and store them in a dictionary:
models = Dict(
:MLP => fit_model(counterfactual_data, :MLP),
:Linear => fit_model(counterfactual_data, :Linear),
)
Next, we store the counterfactual generators of interest in a dictionary as well:
generators = Dict(
:Generic => GenericGenerator(),
:Gravitational => GravitationalGenerator(),
:Wachter => WachterGenerator(),
:ClaPROAR => ClaPROARGenerator(),
)
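Before running the benchmark, we also need one or more factual instances x to explain. How these are chosen is up to you; a minimal sketch, reusing the factual label from the setup above and one of the fitted models (both assumptions, not part of the original example), could be:
# Draw a single individual currently predicted to be in the factual class:
chosen = rand(findall(predict_label(models[:MLP], counterfactual_data) .== factual), 1)
x = select_factual(counterfactual_data, chosen)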
Then we can run a benchmark for individual(s) x, a pre-specified target and counterfactual_data as follows:
bmk = benchmark(x, target, counterfactual_data; models=models, generators=generators)
In this case, metadata is automatically inferred from the dictionaries:
@chain bmk() begin
@filter(variable == "distance")
@select(sample, variable, value, model, generator)
end
8×5 DataFrame
 Row │ sample                                variable  value    model           ⋯
     │ Base.UUID                             String    Float64  Tuple…          ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 2cba5eee-f59f-11ee-1844-cbc7a8372a38  distance  4.38877  (:Linear, Flux  ⋯
   2 │ 2cd740fe-f59f-11ee-35c3-1157eb1b7583  distance  4.17021  (:Linear, Flux
   3 │ 2cd741e2-f59f-11ee-2b09-0d55ef9892b9  distance  4.31145  (:Linear, Flux
   4 │ 2cd7420c-f59f-11ee-1996-6fa75e23bb57  distance  4.17035  (:Linear, Flux
   5 │ 2cd74234-f59f-11ee-0ad0-9f21949f5932  distance  5.73182  (:MLP, FluxMod  ⋯
   6 │ 2cd7425c-f59f-11ee-3eb4-af34f85ffd3d  distance  5.50606  (:MLP, FluxMod
   7 │ 2cd7427a-f59f-11ee-10d3-a1df6c8dc125  distance  5.2114   (:MLP, FluxMod
   8 │ 2cd74298-f59f-11ee-32d1-f501c104fea8  distance  5.3623   (:MLP, FluxMod
                                                                2 columns omitted
Everything at once
Researchers, in particular, may be interested in combining all steps into one. This is the second scenario of Ad Hoc Benchmarking:
- Fit some machine learning model(s), generate counterfactual explanations and benchmark them.
It involves calling benchmark directly on counterfactual data (the only positional argument):
bmk = benchmark(counterfactual_data)
This will use the default models from standard_models_catalogue and train them on the data. All available generators from generator_catalogue will also be used.
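Both catalogues can be inspected to see what this entails (a hedged sketch; the two names are assumed to be in scope after loading the package, otherwise qualify them with the package module):
# List the default model types and the available generators:
keys(standard_models_catalogue)
keys(generator_catalogue)
The resulting benchmark then covers every combination of model and generator: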
@chain bmk() begin
@filter(variable == "validity")
@select(sample, variable, value, model, generator)
end
200×5 DataFrame
 Row │ sample                                variable  value    model   genera ⋯
     │ Base.UUID                             String    Float64  Symbol  Symbol ⋯
─────┼──────────────────────────────────────────────────────────────────────────
   1 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  gravit ⋯
   2 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  growin
   3 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  revise
   4 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  clue
   5 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  probe  ⋯
   6 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  dice
   7 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  clapro
   8 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  wachte
   9 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  generi ⋯
  10 │ 32d1817e-f59f-11ee-152f-a30b18c2e6f7  validity  1.0      Linear  greedy
  11 │ 32d255e8-f59f-11ee-3e8d-a9e9f6e23ea8  validity  1.0      Linear  gravit
  ⋮  │                  ⋮                       ⋮         ⋮       ⋮       ⋮    ⋱
 191 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     gravit
 192 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     growin ⋯
 193 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     revise
 194 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     clue
 195 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     probe
 196 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     dice   ⋯
 197 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     clapro
 198 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     wachte
 199 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     generi
 200 │ 3382d08a-f59f-11ee-10b3-f7d18cf7d3b5  validity  1.0      MLP     greedy ⋯
                                                  1 column and 179 rows omitted
Optionally, you can instead provide a dictionary of models and generators as before. Each value in the models dictionary should be one of two things (see the sketch after this list):
- Either an object M of type AbstractModel that implements the Models.train method.
- Or a DataType that can be called on CounterfactualData to create such an object M.
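A hedged sketch of both options, reusing names from the earlier examples (the specific choices and the reuse of the generators dictionary from above are assumptions):
models = Dict(
    :MLP => fit_model(counterfactual_data, :MLP),  # (a) an already trained AbstractModel
    :Linear => Linear,                             # (b) a DataType that is trained on the data first
)
bmk = benchmark(counterfactual_data; models=models, generators=generators)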
Multiple Datasets
Benchmarks are run on single instances of type CounterfactualData. We made this design choice for two reasons:
- We want to avoid the loops inside the benchmark method(s) becoming too nested and convoluted.
- While it is straightforward to infer metadata for models and generators, this is not the case for datasets.
Fortunately, it is very easy to run benchmarks for multiple datasets anyway, since Benchmark instances can be concatenated. To see how, let's consider an example involving multiple datasets, models and generators:
# Data:
datasets = Dict(
:moons => CounterfactualData(load_moons()...),
:circles => CounterfactualData(load_circles()...),
)
# Models:
models = Dict(
:MLP => FluxModel,
:Linear => Linear,
)
# Generators:
generators = Dict(
:Generic => GenericGenerator(),
:Greedy => GreedyGenerator(),
)
Then we can simply loop over the datasets and concatenate the results like so:
using CounterfactualExplanations.Evaluation: distance_measures
bmks = []
for (dataname, dataset) in datasets
bmk = benchmark(dataset; models=models, generators=generators, measure=distance_measures)
push!(bmks, bmk)
end
bmk = vcat(bmks[1], bmks[2]; ids=collect(keys(datasets)))
When ids are supplied, a new id column is added to the evaluation data frame that contains unique identifiers for the different benchmarks. The optional idcol_name argument can be used to specify the name of that indicator column (it defaults to "dataset").
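For instance, a different column name could be chosen as follows (a hedged sketch; it is not used below, so the default "dataset" name is kept):
# Hypothetical alternative: label the indicator column "data" instead of "dataset".
bmk_alt = vcat(bmks[1], bmks[2]; ids=collect(keys(datasets)), idcol_name="data")
With the default column name, the combined benchmark can then be post-processed per dataset as before: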
@chain bmk() begin
@group_by(dataset, generator)
@filter(model == :MLP)
@filter(variable == "distance_l1")
@summarize(L1_norm=mean(value))
@ungroup
end
4×3 DataFrame
 Row │ dataset  generator  L1_norm
     │ Symbol   Symbol     Float32
─────┼─────────────────────────────
   1 │ moons    Generic    1.56555
   2 │ moons    Greedy     0.819269
   3 │ circles  Generic    1.83524
   4 │ circles  Greedy     0.498953