Counterfactual Rule Explanations

Global and Local Explanations through Surrogate Trees (T-CREx)

counterfactuals
explainable AI
global
decision trees
Julia

New feature in CounterfactualExplanations.jl: Counterfactual Metarules for Local and Global Recourse (Bewley, I. Amoukou, et al. 2024)

Published: September 24, 2024

Counterfactual Explanations were originally proposed as a framework for generating local explanations (Wachter, Mittelstadt, and Russell 2017). Local explanations are specific to individual samples (Molnar 2022). From the perspective of Algorithmic Recourse, also sometimes referred to as Individual Recourse, local explanations are desirable, because ideally we want to be able to generate recourse recommendations that are tailored to individuals. When we are instead primarily interested in explaining the general behaviour of opaque models, local explanations may not be ideal, because they are subject to idiosyncrasy: any given local explanation inherently depends on characteristics that may uniquely apply to the individual in question. Consequently, there has recently been growing interest in generating global or group-level explanations through counterfactuals (Kanamori et al. 2022; Ley, Mishra, and Magazzeni 2023; Warren et al. 2024).

Although CounterfactualExplanations.jl supports generating (multiple) counterfactuals for multiple instances at once, the package has so far lacked support for any of these more sophisticated approaches to generating group-level explanations. To address this gap, we have added support for one of the most novel and performant of these approaches in version v1.3.0. This new approach, called T-CREx, was presented at ICML 2024, where we had the pleasure of meeting and speaking to Tom Bewley, the corresponding author of the paper (Bewley, I. Amoukou, et al. 2024).

This post introduces the implementation of the TCRExGenerator from the ground up. Since the approach and implementation depart from the existing focus of CounterfactualExplanations.jl on gradient-based generators, we expect changes and extensions to this first implementation, which will be covered in a follow-up blog post.

Background

T-CREx stands for Trees for Counterfactual Rule Explanation (Bewley, I. Amoukou, et al. 2024). As the name implies, the approach relies on a tree-based surrogate model to generate counterfactual explanations. The authors distinguish between Counterfactual Point Explanations (CPE) and Counterfactual Rule Explanations (CRE). Most readers will already be familiar with the concept of CPE, since all of our existing counterfactual generators fall into this category: for some individual (factual), they generate an associated counterfactual point in some target domain. Point explanations are very explicit in nature, since they prescribe exact feature values for the counterfactual. Rule explanations (CRE), on the other hand, are more flexible: they merely put bounds on the values of certain features, usually a subset of all features.

Example: CPE vs. CRE

Suppose we want to provide recourse to an unsuccessful job applicant with 3 years of relevant job experience and a 7.1/10 GPA in their bachelor’s degree. A valid CPE might be that in order to secure the job, the individual should have exactly 4 years of relevant job experience and an 8.5/10 GPA. A corresponding CRE might be that the individual should have achieved a GPA of 9.0/10 or higher in order to qualify for the job, independent of prior job experience.
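
To make this concrete, the two explanation types could be represented in code as follows (a purely illustrative sketch, not the package’s API; the field names are made up for this example):

Code
# CPE: an exact counterfactual point, prescribing one value per feature
cpe = (experience = 4.0, gpa = 8.5)

# CRE: per-feature bounds; features without bounds remain unconstrained
cre = (gpa = (9.0, Inf),)    # "GPA of 9.0/10 or higher", experience left free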

Implementation

As mentioned above, the basic implementation of T-CREx departs from the existing design of our package. This is primarily due to the fact that T-CREx introduces the new concept of Counterfactual Rule Explanations (CRE) described in Section 1. We have therefore chosen to experiment with a new API for this generator. As with all other generators, the TCRExGenerator has its own concrete type. Contrary to existing generators, however, this type can be called directly on a specified target label, data and model, reflecting its global nature: T-CREx does not need a factual as a reference and starting point to generate a corresponding CPE. Rule explanations for a given target class can be generated using only the model and corresponding training data. The learned CRE can then be applied to individual instances, which effectively boils down to choosing the optimal rule for said individual from a set of learned rules. If this all sounds a bit confusing at this point, read on in Section 2.1 for a concrete usage example.

Breaking Changes Expected

Work on this feature is still in its very early stages and breaking changes should be expected. This new generator introduces concepts such as global counterfactual explanations that are not explained anywhere else in this documentation. If you want to use this generator, please make sure you are familiar with the related literature.

Usage Example

The implementation of the TCRExGenerator depends on DecisionTree.jl. For the time being, we have decided not to add a strong dependency on DecisionTree.jl to the package. Instead, the functionality of the TCRExGenerator is made available through the DecisionTreeExt extension, which is loaded conditionally when DecisionTree.jl itself is loaded (see the Julia docs on package extensions for more details):

Code
using DecisionTree

Next, we load some other dependencies used in this tutorial:

Code
using CategoricalArrays
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using CounterfactualExplanations.Models
using Plots
using Random
using TaijaPlotting

# Setting up color palette as in paper:
col_pal = palette(:seaborn_bright)[[4,1,2,3,6,5,7,8,9]];

# Set seed for reproducibility:
Random.seed!(2024)

Let us first set up the problem by loading some data. To reproduce the example in Bewley, I. Amoukou, et al. (2024) as accurately as possible, we use Python’s scikit-learn to load the synthetic data:

Code
using CondaPkg; CondaPkg.add("scikit-learn");
using PythonCall;
skd = pyimport("sklearn.datasets");
n = 5000
X, y = skd.make_moons(n_samples=n, noise=0.3, random_state=0)
X = pyconvert(Matrix, X) |> permutedims |> x -> Float32.(x)
y = pyconvert(Vector, y)

Next, we wrap the data in a CounterfactualData container, fit a simple classification model to the data and store the model predictions for the entire training dataset (we need these to train the tree-based surrogate model).

Code
# Counterfactual data and model:
data = CounterfactualData(X, y)
flux_training_params.batchsize = 100
M = fit_model(data, :MLP)

Finally, we determine a target and factual class and choose a random sample from the factual class:

Code
target = 1
factual = 0
chosen = rand(findall(predict_label(M, data) .== factual))
x = select_factual(data, chosen)

Next, we instantiate the generator much like any other counterfactual generator in our package:

Code
ρ = 0.02        # feasibility threshold (see Bewley et al. (2024))
τ = 0.9         # accuracy threshold (see Bewley et al. (2024))
generator = Generators.TCRExGenerator(ρ=ρ, τ=τ)

As mentioned above, this instance is callable. In particular, we can call it on the given target, data and model to generate a (global) counterfactual rule explanation (CRE) as follows:

Code
cre = generator(target, data, M)        # counterfactual rule explanation (global)

To generate a local explanation (CPE) for our factual instance x, the learned cre can itself be called on x. This returns the index of the optimal rule for x, along with the rule itself:

Code
idx, optimal_rule = cre(x)              # counterfactual point explanation (local)

What we have seen so far corresponds to the current user-facing API for the T-CREx generator. If that is all you were after when opening this blog post, you can stop here. Otherwise, read on to find out how exactly the approach works, step-by-step.

Worked Example from Bewley, I. Amoukou, et al. (2024)

To make better sense of this, we will now go through the worked example presented in Bewley, I. Amoukou, et al. (2024). For this purpose, we need to make the functions of the DecisionTreeExt extension available.

Private API

Please note that the DecisionTreeExt extension is accessed here purely for demonstrative purposes. You should not load the extension like this in your own work.

Code
DTExt = Base.get_extension(CounterfactualExplanations, :DecisionTreeExt)

(a) Tree-based surrogate model

In the first step, we train a tree-based surrogate model based on the data and the black-box model M. Specifically, the surrogate model is trained on pairs of observed input data and the labels predicted by the black-box model: \(\{(x_i, M(x_i))\}_{1\leq i \leq n}\).
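
In code, constructing these training labels would simply amount to querying the model for its predictions on the training inputs, for instance via the package’s predict_label function (as noted in the callout below, this example instead assumes a perfectly accurate oracle and uses the observed labels directly):

Code
# In general, the surrogate's training labels are the black-box predictions:
fx = predict_label(M, data)    # M(xᵢ) for every training input xᵢ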

Oracle Black-Box

As in the paper, we assume here that the black-box model is an oracle with perfect accuracy. This is done purely to stay as close as possible to the example in the paper.

Following Bewley, I. Amoukou, et al. (2024), we impose a minimum number of samples per leaf to ensure counterfactual feasibility (also often referred to as plausibility). This number is computed under the hood, based on the generator.ρ field of the TCRExGenerator, which can be used to specify the minimum fraction of all samples that must be contained by any given rule.
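
As a rough illustration of how the feasibility threshold translates into a leaf-size constraint (the exact computation happens inside the extension and may differ):

Code
# Illustrative sketch: a rule (leaf) must cover at least a fraction ρ of all
# n samples, so the minimum number of samples per leaf is roughly:
min_points = ceil(Int, generator.ρ * size(X, 2))    # here: ceil(0.02 * 5000) = 100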

Code
# Surrogate:
Xtrain = permutedims(X)
ytrain = categorical(y)
fx = ytrain                 # assume perfect accuracy
model, fitresult = DTExt.grow_surrogate(generator, Xtrain, fx)
M_sur = CounterfactualExplanations.DecisionTreeModel(model; fitresult=fitresult)

We can reassure ourselves that the feasibility constraint is indeed respected:

Code
# Extract rules:
R = DTExt.extract_rules(fitresult[1])

# Compute feasibility and accuracy:
feas = DTExt.rule_feasibility.(R, (X,))
@assert minimum(feas) >= ρ
@info "Minimum fraction of samples across all rules is $(round(minimum(feas), digits=3))"
acc_factual = DTExt.rule_accuracy.(R, (X,), (fx,), (factual,))
acc_target = DTExt.rule_accuracy.(R, (X,), (fx,), (target,))
@assert all(acc_target .+ acc_factual .== 1.0)
[ Info: Minimum fraction of samples across all rules is 0.02
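
Note that each extracted rule is simply a vector of per-feature (lower, upper) bounds, which is also how the plotting code below indexes into the rules. For our two-dimensional input, a rule might look as follows (illustrative values, not actual output):

Code
R[1]    # e.g. [(-Inf, 0.5), (-1.2, Inf)]: feature 1 bounded above, feature 2 bounded below
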
Code
plt = plot(data; ms=2, markerstrokewidth=0, size=(500, 500), palette=col_pal, alpha=0.5)
rectangle(w, h, x, y) = Shape(x .+ [0,w,w,0], y .+ [0,0,h,h])
function plot_grid!(p, grid)
    for (i, (bounds_x, bounds_y)) in enumerate(grid)
        lbx, ubx = bounds_x
        lby, uby  = bounds_y
        lbx = maximum([lbx, minimum(X[1, :])])
        lby = maximum([lby, minimum(X[2, :])])
        ubx = minimum([ubx, maximum(X[1, :])])
        uby = minimum([uby, maximum(X[2, :])])
        plot!(
            p,
            rectangle(ubx - lbx, uby - lby, lbx, lby);
            fillcolor="black",
            fillalpha=0.0,
            label=nothing,
            lw=2, palette=col_pal
        )
    end
end
plot_grid!(plt, R)
plt
Figure 1: Tree-based surrogate model

(b) Maximal-valid rules

From the complete set of rules derived from the surrogate tree, we next identify the maximal-valid rules. Intuitively, “a maximal-valid rule is one that cannot be made any larger without violating the validity conditions”, where validity is defined in terms of both feasibility (generator.ρ) and accuracy (generator.τ).
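
For intuition, here is a minimal sketch of what these two validity ingredients measure for a single rule (the package provides its own implementations, DTExt.rule_feasibility and DTExt.rule_accuracy, which are used below):

Code
using Statistics: mean

# A rule is a vector of per-feature (lower, upper) bounds:
covers(rule, x) = all(lb <= x[j] <= ub for (j, (lb, ub)) in enumerate(rule))

# Feasibility: the fraction of all samples covered by the rule.
feasibility(rule, X) = mean(covers(rule, X[:, i]) for i in 1:size(X, 2))

# Accuracy: the fraction of covered samples whose (surrogate) label is the target.
function accuracy(rule, X, fx, target)
    covered = [i for i in 1:size(X, 2) if covers(rule, X[:, i])]
    return mean(fx[covered] .== target)
end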

Code
R_max = DTExt.max_valid(R, X, fx, target, τ)
feas_max = DTExt.rule_feasibility.(R_max, (X,))
acc_max = DTExt.rule_accuracy.(R_max, (X,), (fx,), (target,))
p1 = deepcopy(plt)
function plot_surr!(plt)
    for (i, rule) in enumerate(R_max)
        ubx, uby = minimum([rule[1][2], maximum(X[1, :])]),
        minimum([rule[2][2], maximum(X[2, :])])
        lbx, lby = maximum([rule[1][1], minimum(X[1, :])]),
        maximum([rule[2][1], minimum(X[2, :])])
        _feas = round(feas_max[i]; digits=2)
        _n = round(Int, feas_max[i] * n)
        _acc = round(acc_max[i]; digits=2)
        @info "Rectangle R$i with feasibility $(_feas) (n≈$(_n)) and accuracy $(_acc)"
        lab = "R$i (ρ̂=$(_feas), τ̂=$(_acc))"
        plot!(plt, rectangle(ubx-lbx,uby-lby,lbx,lby), opacity=.5, color=i+2, label=lab, palette=col_pal)
    end
end
plot_surr!(p1)
p1
[ Info: Rectangle R1 with feasibility 0.22 (n≈1094) and accuracy 0.98
[ Info: Rectangle R2 with feasibility 0.07 (n≈335) and accuracy 0.92
[ Info: Rectangle R3 with feasibility 0.08 (n≈383) and accuracy 0.98
[ Info: Rectangle R4 with feasibility 0.04 (n≈211) and accuracy 0.94