Counterfactual Rule Explanations

Global and Local Explanations through Surrogate Trees (T-CREx)

counterfactuals
explainable AI
global
decision trees
Julia

New feature in CounterfactualExplanations.jl: Counterfactual Metarules for Local and Global Recourse (Bewley, Amoukou, et al. 2024)

Published

September 24, 2024

Counterfactual Explanations were originally proposed as a framework for generating local explanations (Wachter, Mittelstadt, and Russell 2017). Local explanations are specific to individual samples (Molnar 2022). From the perspective of Algorithmic Recourse, also sometimes referred to as Individual Recourse, local explanations are desirable, because ideally we want to be able to generate recourse recommendations that are tailored to individuals. When we are instead primarily interested in explaining the general behaviour of opaque models, then local explanations may not be ideal, because they are subject to idiosyncrasy: any given local explanation inherently depends on characteristics that may uniquely apply to the individual in question. Consequently, there has recently been growing interest in generating global or group-level explanations through counterfactuals (Kanamori et al. 2022; Ley, Mishra, and Magazzeni 2023; Warren et al. 2024).

Although CounterfactualExplanations.jl supports generating (multiple) counterfactuals for multiple instances at once, the package has so far lacked support for any of these more sophisticated approaches to generating group-level explanations. To address this gap, we have added support for one of the most recent and performant approaches in version v1.3.0. This new approach, called T-CREx, was presented at ICML 2024, where we had the pleasure of meeting and speaking to Tom Bewley, the corresponding author of the paper (Bewley, I. Amoukou, et al. 2024).

This post introduces the implementation of the TCRExGenerator from the ground up. Since the approach and implementation depart from the existing focus of CounterfactualExplanations.jl on gradient-based generators, we expect future changes and extensions to this first implementation, which will be covered in a future follow-up blog post.

Background

T-CREx stands for Trees for Counterfactual Rule Explanation (Bewley, I. Amoukou, et al. 2024). As the name implies, the approach relies on a tree-based surrogate model to generate counterfactual explanations. The authors make the distinction between Counterfactual Point Explanations (CPE) and Counterfactual Rule Explanations (CRE). Most readers will already be familiar with the concept of CPE, since all of our existing counterfactual generators fall into this category: for some individual (factual), they generate an associated counterfactual point in some target domain. Point explanations are very explicit in nature, since they prescribe exact feature values for the counterfactual. Rule explanations (CRE), on the other hand, are more flexible in nature: they merely put bounds on the values of certain features, usually a subset of all features.

Example: CPE vs. CRE

Suppose we want to provide recourse to an unsuccessful job applicant with 3 years of relevant job experience and a 7.1/10 GPA in their bachelor’s degree. A valid CPE might be that in order to secure the job, the individual should have exactly 4 years of relevant job experience and an 8.5/10 GPA. A corresponding CRE might be that the individual should have achieved a GPA of 9.0/10 or higher in order to qualify for the job, independent of prior job experience.
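
To make the distinction a little more tangible, the sketch below (purely illustrative and not part of the package API) represents the applicant’s features as a vector, a CPE as an exact counterfactual point and a CRE as per-feature bounds that only constrain the GPA:

Code
# Purely illustrative sketch; not part of the package API.
# Features: (years of relevant experience, GPA out of 10).
x_factual = [3.0, 7.1]

# CPE: an exact counterfactual point.
x_cpe = [4.0, 8.5]

# CRE: per-feature (lower, upper) bounds; experience is left unconstrained.
cre_rule = [(-Inf, Inf), (9.0, Inf)]

# A point satisfies a rule if every feature lies within its bounds:
satisfies(x, rule) = all(lb <= v <= ub for (v, (lb, ub)) in zip(x, rule))
satisfies(x_cpe, cre_rule)        # false: the rule asks for a GPA of 9.0 or higher
satisfies([3.0, 9.2], cre_rule)   # true: any GPA ≥ 9.0 qualifies, regardless of experience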

Implementation

As mentioned above, the basic implementation of T-CREx departs from the existing design of our package. This is primarily due to the fact that T-CREx introduces the new concept of Counterfactual Rule Explanations (CRE) described in Section 1. We have therefore chosen to experiment with a new API for this generator. As with all other generators, the TCRExGenerator has its own concrete type. Unlike existing generators, this type can be called directly on a specified target label, data and model, reflecting its global nature: T-CREx does not need a factual as a reference and starting point to generate a corresponding CPE. Rule explanations for a given target class can be generated using only the model and the corresponding training data. The learned CRE can then be applied to individual instances, which effectively boils down to choosing the optimal rule for said individual from the set of learned rules. If this all sounds a bit confusing at this point, read on in Section 2.1 for a concrete usage example.

Breaking Changes Expected

Work on this feature is still in its very early stages and breaking changes should be expected. This new generator introduces concepts, such as global counterfactual explanations, that are not explained anywhere else in this documentation. If you want to use this generator, please make sure you are familiar with the related literature.

Usage Example

The implementation of the TCRExGenerator depends on DecisionTree.jl. For the time being, we have decided not to add a strong dependency on DecisionTree.jl to the package. Instead, the functionality of the TCRExGenerator is made available through the DecisionTreeExt extension, which is loaded conditionally when DecisionTree.jl itself is loaded (see the Julia documentation on package extensions for more details):

Code
using DecisionTree

Next, we load some other dependencies used in this tutorial:

Code
using CategoricalArrays
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using CounterfactualExplanations.Models
using Plots
using Random
using TaijaPlotting

# Setting up color palette as in paper:
col_pal = palette(:seaborn_bright)[[4,1,2,3,6,5,7,8,9]];

# Set seed for reproducibility:
Random.seed!(2024)

Let us first set up the problem by loading some data. To reproduce the example in Bewley, I. Amoukou, et al. (2024) as accurately as possible, we use Python’s scikit-learn to load the synthetic data:

Code
using CondaPkg; CondaPkg.add("scikit-learn");
using PythonCall;
skd = pyimport("sklearn.datasets");
n = 5000
X, y = skd.make_moons(n_samples=n, noise=0.3, random_state=0)
X = pyconvert(Matrix, X) |> permutedims |> x -> Float32.(x)
y = pyconvert(Vector, y)

Next, we wrap the data in a CounterfactualData container, fit a simple classification model to the data and store the model predictions for the entire training dataset (we need these to train the tree-based surrogate model).

Code
# Counterfactual data and model:
data = CounterfactualData(X, y)
flux_training_params.batchsize = 100
M = fit_model(data, :MLP)

Finally, we determine a target and factual class and choose a random sample from the factual class:

Code
target = 1
factual = 0
chosen = rand(findall(predict_label(M, data) .== factual))
x = select_factual(data, chosen)

Next, we instantiate the generator much like any other counterfactual generator in our package:

Code
ρ = 0.02        # feasibility threshold (see Bewley et al. (2024))
τ = 0.9         # accuracy threshold (see Bewley et al. (2024))
generator = Generators.TCRExGenerator(ρ=ρ, τ=τ)

As mentioned above, this instance is callable. In particular, we can call it on the given target, data and model to generate a (global) counterfactual rule explanation (CRE) as follows:

Code
cre = generator(target, data, M)        # counterfactual rule explanation (global)

To generate a local explanation (CPE) for our factual instance x, the learned cre can itself be called on x:

Code
idx, optimal_rule = cre(x)              # counterfactual point explanation (local)

What we have seen so far corresponds to the current user-facing API for the T-CREx generator. If that is all you were after when opening this blog post, you can stop here. Otherwise, read on to find out how exactly the approach works, step-by-step.

Worked Example from Bewley, I. Amoukou, et al. (2024)

To make better sense of this, we will now go through the worked example presented in Bewley, I. Amoukou, et al. (2024). For this purpose, we need to make the functions of the DecisionTreeExt extension available.

Private API

Please note that the DecisionTreeExt extension is accessed directly here purely for demonstrative purposes. You should not access the extension like this in your own work.

Code
DTExt = Base.get_extension(CounterfactualExplanations, :DecisionTreeExt)

(a) Tree-based surrogate model

In the first step, we train a tree-based surrogate model based on the data and the black-box model M. Specifically, the surrogate model is trained on pairs of observed input data and the labels predicted by the black-box model: \(\{(x_i, M(x_i))\}_{1\leq i \leq n}\).

Oracle Black-Box

As in the paper, we assume here that the black-box model is an oracle with perfect accuracy. This is done purely to stay as close as possible to the example in the paper.
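
In a realistic setting without an oracle, the surrogate targets would instead be the black-box model’s own predictions on the training data, for example along these lines (a sketch, not what we do below):

Code
# Non-oracle variant (sketch): use the black-box predictions as surrogate targets
# instead of the observed labels that we use below.
fx_blackbox = predict_label(M, data)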

Following Bewley, I. Amoukou, et al. (2024), we impose a minimum number of samples per leaf to ensure counterfactual feasibility (also often referred to as plausibility). This number is computed under the hood based on the generator.ρ field of the TCRExGenerator, which can be used to specify the minimum fraction of all samples that is contained by any given rule.
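
As a rough sketch of the idea (the exact internal computation may differ), the minimum leaf size follows directly from generator.ρ and the number of training samples:

Code
# Sketch: with ρ = 0.02 and n = 5000, every leaf of the surrogate tree (and hence
# every extracted rule) must contain at least a fraction ρ of all samples:
n_min = ceil(Int, generator.ρ * n)    # 100 samples per leaf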

Code
# Surrogate:
Xtrain = permutedims(X)
ytrain = categorical(y)
fx = ytrain                 # assume perfect accuracy
model, fitresult = DTExt.grow_surrogate(generator, Xtrain, fx)
M_sur = CounterfactualExplanations.DecisionTreeModel(model; fitresult=fitresult)

We can reassure ourselves that the feasibility constraint is indeed respected:

Code
# Extract rules:
R = DTExt.extract_rules(fitresult[1])

# Compute feasibility and accuracy:
feas = DTExt.rule_feasibility.(R, (X,))
@assert minimum(feas) >= ρ
@info "Minimum fraction of samples across all rules is $(round(minimum(feas), digits=3))"
acc_factual = DTExt.rule_accuracy.(R, (X,), (fx,), (factual,))
acc_target = DTExt.rule_accuracy.(R, (X,), (fx,), (target,))
@assert all(acc_target .+ acc_factual .== 1.0)
[ Info: Minimum fraction of samples across all rules is 0.02
Code
plt = plot(data; ms=2, markerstrokewidth=0, size=(500, 500), palette=col_pal, alpha=0.5)
rectangle(w, h, x, y) = Shape(x .+ [0,w,w,0], y .+ [0,0,h,h])
function plot_grid!(p, grid)
    for (i, (bounds_x, bounds_y)) in enumerate(grid)
        lbx, ubx = bounds_x
        lby, uby  = bounds_y
        lbx = maximum([lbx, minimum(X[1, :])])
        lby = maximum([lby, minimum(X[2, :])])
        ubx = minimum([ubx, maximum(X[1, :])])
        uby = minimum([uby, maximum(X[2, :])])
        plot!(
            p,
            rectangle(ubx - lbx, uby - lby, lbx, lby);
            fillcolor="black",
            fillalpha=0.0,
            label=nothing,
            lw=2, palette=col_pal
        )
    end
end
plot_grid!(plt, R)
plt
Figure 1: Tree-based surrogate model

(b) Maximal-valid rules

From the complete set of rules derived from the surrogate tree, we can derive the maximal-valid rules next. Intuitively, “a maximal-valid rule is one that cannot be made any larger without violating the validity conditions”, where validity is defined in terms of both feasibility (generator.ρ) and accuracy (generator.τ).
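
To spell out what these two quantities measure, here is a hedged sketch that mirrors the idea behind DTExt.rule_feasibility and DTExt.rule_accuracy (the actual implementation may differ), assuming a rule is stored as per-feature (lower, upper) bounds and X holds one sample per column:

Code
using Statistics: mean

# Sketch only; not the package implementation.
in_rule(x, rule) = all(lb <= v <= ub for (v, (lb, ub)) in zip(x, rule))

# Feasibility (ρ̂): fraction of all samples covered by the rule.
rule_feasibility_sketch(rule, X) = mean(in_rule(x, rule) for x in eachcol(X))

# Accuracy (τ̂): among covered samples, fraction labelled with the target class.
function rule_accuracy_sketch(rule, X, fx, target)
    covered = [in_rule(x, rule) for x in eachcol(X)]
    return mean(fx[covered] .== target)
end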

Code
R_max = DTExt.max_valid(R, X, fx, target, τ)
feas_max = DTExt.rule_feasibility.(R_max, (X,))
acc_max = DTExt.rule_accuracy.(R_max, (X,), (fx,), (target,))
p1 = deepcopy(plt)
function plot_surr!(plt)
    for (i, rule) in enumerate(R_max)
        ubx, uby = minimum([rule[1][2], maximum(X[1, :])]),
        minimum([rule[2][2], maximum(X[2, :])])
        lbx, lby = maximum([rule[1][1], minimum(X[1, :])]),
        maximum([rule[2][1], minimum(X[2, :])])
        _feas = round(feas_max[i]; digits=2)
        _n = Int(round(feas_max[i] * n; digits=2))
        _acc = round(acc_max[i]; digits=2)
        @info "Rectangle R$i with feasibility $(_feas) (n≈$(_n)) and accuracy $(_acc)"
        lab = "R$i (ρ̂=$(_feas), τ̂=$(_acc))"
        plot!(plt, rectangle(ubx-lbx,uby-lby,lbx,lby), opacity=.5, color=i+2, label=lab, palette=col_pal)
    end
end
plot_surr!(p1)
p1
[ Info: Rectangle R1 with feasibility 0.22 (n≈1094) and accuracy 0.98
[ Info: Rectangle R2 with feasibility 0.07 (n≈335) and accuracy 0.92
[ Info: Rectangle R3 with feasibility 0.08 (n≈383) and accuracy 0.98
[ Info: Rectangle R4 with feasibility 0.04 (n≈211) and accuracy 0.94
Figure 2: Maximal-valid rules.

(c) Induced grid partition

Based on the set of maximal-valid rules, we compute and plot the induced grid partition below.
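
Conceptually, the grid is obtained by collecting, for each feature, all bounds that appear in any maximal-valid rule and taking the Cartesian product of the resulting consecutive intervals. The sketch below illustrates that idea under the assumption that rules are stored as per-feature (lower, upper) bounds; the package’s DTExt.induced_grid handles the details and may differ:

Code
# Sketch only; not the package implementation.
function induced_grid_sketch(rules)
    d = length(first(rules))
    # Per feature: sorted unique bound values across all rules.
    cuts = [sort(unique(vcat([[r[j][1], r[j][2]] for r in rules]...))) for j in 1:d]
    # Per feature: consecutive intervals between those values.
    intervals = [collect(zip(c[1:(end - 1)], c[2:end])) for c in cuts]
    # Grid cells: Cartesian product of the per-feature intervals.
    return vec(collect(Iterators.product(intervals...)))
end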

Code
_grid = DTExt.induced_grid(R_max)

plt = plot(data; ms=2, markerstrokewidth=0, size=(500, 500), palette=col_pal, alpha=0.1)
p2 = deepcopy(plt)
plot_surr!(p2)
plot_grid!(p2, _grid)
p2
[ Info: Rectangle R1 with feasibility 0.22 (n≈1094) and accuracy 0.98
[ Info: Rectangle R2 with feasibility 0.07 (n≈335) and accuracy 0.92
[ Info: Rectangle R3 with feasibility 0.08 (n≈383) and accuracy 0.98
[ Info: Rectangle R4 with feasibility 0.04 (n≈211) and accuracy 0.94
Figure 3: Induced grid partition.

(d) Grid cell prototypes

Next, we pick prototypes from each cell in the induced grid. By setting pick_arbitrary=false here we enforce that prototypes correspond to cell centroids, although this is not strictly necessary. For each prototype, we compute the corresponding CRE, which is indicated by the color of the large markers in the figure below:
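
For intuition, a centroid prototype of a grid cell can be thought of as the mean of the training samples that fall inside that cell. The sketch below illustrates that idea (the package’s DTExt.prototype may compute prototypes differently):

Code
using Statistics: mean

# Sketch only: cell is a tuple of per-feature (lower, upper) bounds and X holds
# one sample per column.
function prototype_sketch(cell, X)
    inside = [all(lb <= v <= ub for (v, (lb, ub)) in zip(x, cell)) for x in eachcol(X)]
    return vec(mean(X[:, inside]; dims=2))
end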

Code
xs = DTExt.prototype.(_grid, (X,); pick_arbitrary=false)
Rᶜ = DTExt.cre.((R_max,), xs, (X,); return_index=true) 
p3 = deepcopy(p2)
scatter!(p3, eachrow(hcat(xs...))..., ms=10, label=nothing, color=Rᶜ.+2)
p3
Figure 4: Grid cell prototypes.

(e) - (f) Global CE representation

Based on the prototypes and their corresponding rule assignments, we fit a CART classification tree with restricted feature thresholds. Specifically, feature thresholds are restricted to the partition bounds induced by the set of maximal-valid rules, as in Bewley, I. Amoukou, et al. (2024). The figure below shows the resulting global CE representation (i.e. the metarules).
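
In essence, the restricted thresholds are just the finite bounds appearing in the maximal-valid rules, collected per feature. A small sketch of that idea (the package exposes this via DTExt.partition_bounds, whose exact behaviour may differ):

Code
# Sketch only: candidate split thresholds per feature are the finite bounds
# appearing in any maximal-valid rule.
function partition_bounds_sketch(rules)
    d = length(first(rules))
    return [sort(unique(b for r in rules for b in r[j] if isfinite(b))) for j in 1:d]
end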

Code
bounds = DTExt.partition_bounds(R_max)
tree = DTExt.classify_prototypes(hcat(xs...)', Rᶜ, bounds)
R_final, labels = DTExt.extract_leaf_rules(tree) 
p4 = deepcopy(plt)
for (i, rule) in enumerate(R_final)
    ubx, uby = minimum([rule[1][2], maximum(X[1, :])]),
    minimum([rule[2][2], maximum(X[2, :])])
    lbx, lby = maximum([rule[1][1], minimum(X[1, :])]),
    maximum([rule[2][1], minimum(X[2, :])])
    plot!(
        p4,
        rectangle(ubx - lbx, uby - lby, lbx, lby);
        fillalpha=0.5,
        label=nothing,
        color=labels[i] + 2
    )
end
p4
Figure 5: Global CE representation.

(g) Local CE example

To generate a local explanation based on the global CE representation, we simply apply the CART decision tree classifier from the previous step to our factual:

Code
optimal_rule = apply_tree(tree, vec(x))
p5 = deepcopy(p2)
scatter!(p5, [x[1]], [x[2]], ms=10, color=2+optimal_rule, label="Local CE (move to R$optimal_rule)")
p5
Figure 6: Local CE example.

Conclusion

This blog post has introduced the TCRExGenerator, a novel approach to generating global counterfactual explanations. We have explained the high-level details of the approach proposed by Bewley, I. Amoukou, et al. (2024), presented the current version of the corresponding API in CounterfactualExplanations.jl and, finally, described the details of the implementation. Since this approach and implementation represent a departure from the existing focus and design of our package, we expect future breaking changes and extensions.

Acknowledgements

Patrick is grateful to Tom Bewley for an interesting chat during ICML and his support with this implementation.

References

Bewley, Tom, Salim I. Amoukou, Saumitra Mishra, Daniele Magazzeni, and Manuela Veloso. 2024. “Counterfactual Metarules for Local and Global Recourse.” https://arxiv.org/abs/2405.18875.
Bewley, Tom, Salim I. Amoukou, Saumitra Mishra, Daniele Magazzeni, and Manuela Veloso. 2024. “Counterfactual Metarules for Local and Global Recourse.” In Proceedings of the 41st International Conference on Machine Learning, edited by Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, 235:3707–24. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v235/bewley24a.html.
Kanamori, Kentaro, Takuya Takagi, Ken Kobayashi, and Yuichi Ike. 2022. “Counterfactual Explanation Trees: Transparent and Consistent Actionable Recourse with Decision Trees.” In Proceedings of the 25th International Conference on Artificial Intelligence and Statistics, edited by Gustau Camps-Valls, Francisco J. R. Ruiz, and Isabel Valera, 151:1846–70. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v151/kanamori22a.html.
Ley, Dan, Saumitra Mishra, and Daniele Magazzeni. 2023. “GLOBE-CE: A Translation Based Approach for Global Counterfactual Explanations.” In Proceedings of the 40th International Conference on Machine Learning, edited by Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, 202:19315–42. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v202/ley23a.html.
Molnar, Christoph. 2022. Interpretable Machine Learning: A Guide for Making Black Box Models Explainable. 2nd ed. https://christophm.github.io/interpretable-ml-book.
Wachter, Sandra, Brent Mittelstadt, and Chris Russell. 2017. “Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR.” Harv. JL & Tech. 31: 841. https://doi.org/10.2139/ssrn.3063289.
Warren, Greta, Eoin Delaney, Christophe Guéret, and Mark T. Keane. 2024. “Explaining Multiple Instances Counterfactually: User Tests of Group-Counterfactuals for XAI.” In Case-Based Reasoning Research and Development: 32nd International Conference, ICCBR 2024, Merida, Mexico, July 1–4, 2024, Proceedings, 206–22. Berlin, Heidelberg: Springer-Verlag. https://doi.org/10.1007/978-3-031-63646-2_14.

Citation

BibTeX citation:
@online{altmeyer2024,
  author = {Altmeyer, Patrick},
  title = {Counterfactual {Rule} {Explanations}},
  date = {2024-09-24},
  url = {https://www.taija.org/blog/posts/counterfactual-rule-explanations/},
  langid = {en}
}
For attribution, please cite this work as:
Altmeyer, Patrick. 2024. “Counterfactual Rule Explanations.” September 24, 2024. https://www.taija.org/blog/posts/counterfactual-rule-explanations/.