When Causality meets Recourse

Counterfactual Explanations through Structural Causal Models

counterfactuals
explainable AI
causality
Julia

This post introduces a new tool in CounterfactualExplanations.jl, enhancing the package with causal reasoning to generate counterfactual explanations.

Author
Published

September 17, 2024

Introduction

In recent years, the need for interpretable and explainable AI has surged, particularly in high-stakes domains. Counterfactual explanations provide a means to understand how changes to input features could alter the outcomes of machine learning models. This blog post presents a new tool in the CounterfactualExplanations.jl package, developed during my JSoC (Julia Summer of Code) project, which incorporates causal reasoning into counterfactual generation.

Testimonial

This was an amazing experience, not just because I got to contribute to two repositories simultaneously, but also because I got to work with the maintainers of these repos. I learned a lot about the Julia language and the Julia community. This was possible thanks to the mentorship of Patrick Altmeyer (CounterfactualExplanations) and Moritz Schauer (CausalInference), who guided me throughout the project and are amazing researchers.

Project Overview

This project aimed to enhance the CounterfactualExplanations.jl package by infusing it with a robust mathematical foundation for minimal algorithmic recourse, based on the principles of causal reasoning (Karimi, Schölkopf, and Valera 2021).

Key Contributions

During the project, I contributed to two key repositories:

  1. CounterfactualExplanations.jl: Developed a new tool for generating counterfactual explanations using causal information. This allows users to generate counterfactuals through causal interventions rather than minimal perturbations, ultimately providing more meaningful insights.

  2. CausalInference.jl: Implemented a Structural Causal Model (SCM) structure that extracts information from data, laying the groundwork for the causal reasoning capabilities in CounterfactualExplanations.jl.

Theoretical Background

In this project, we developed a framework for the MINT Generator: a counterfactual generator based on the Recourse through Minimal Intervention (MINT) method proposed by Karimi, Schölkopf, and Valera (2021).

The MINT Generator incorporates causal reasoning to achieve algorithmic recourse through minimal interventions. The main idea is that perturbing the inputs of a black-box model without taking into account the causal relations in the data can lead to misleading recommendations. We therefore shift to a perspective where every feature perturbation is an intervention in the causal graph of the problem. By leveraging causal relationships, interventions on causal parents automatically lead to potentially useful changes in their causal children. The generator uses a Structural Causal Model (SCM) to encode the variables in such a way that causal effects are propagated, and it uses a generic gradient-based generator to create the search path. This has the benefit that any existing gradient-based generator, such as ECCo (Altmeyer et al. 2024), Wachter (Wachter, Mittelstadt, and Russell 2017), DiCE (Mothilal, Sharma, and Tan 2020), and more, can be used with the MINT SCM encoder to generate counterfactuals through causal interventions.

The MINT algorithm minimizes a loss function that combines the causal constraints of the SCM and the distance between the generated counterfactual and the original input. Since we want a gradient-based generator, we need to turn the constrained optimization problem into an unconstrained one, and we do this by using the Lagrangian. Initially, as defined in Karimi, Schölkopf, and Valera (2021), we aim to find the minimal-cost set of actions \(A\) (in the form of structural interventions) that results in a counterfactual instance yielding the favorable output from \(h\),

\[ \begin{aligned} A^* \in \arg\min_A \text{cost}(A; \mathbf{x}_F)\\ \textrm{s.t.} \quad h(\mathbf{x}_{SCF}) \neq h(\mathbf{x}_F) \; \; \text{,}\\ \end{aligned} \]

where \(\mathbf{x}_F\) is the original input, \(\mathbf{x}_{SCF}\) is the counterfactual instance, and \(h\) is the black-box model. We use the \(\mathbf{x}_{SCF}\) terminology because the counterfactual is derived from the SCM,

\[ x_{SCF_i} = \begin{cases} x_{F_i} + \delta_i, & \text{if } i \in I \\ x_{F_i} + f_i(\text{pa}_{SCF_i}) - f_i(\text{pa}_{F_i}), & \text{if } i \notin I \; \; \text{,} \end{cases} \]

where \(I\) is the set of intervened-upon variables, \(f_i\) is the function that generates the value of the variable \(i\) given its parents, and \(\text{pa}_{SCF_i}\) and \(\text{pa}_{F_i}\) are the parents of the variable \(i\) in the counterfactual and original instance, respectively. This closed-form expression for the decision variable \(\mathbf{x}_{SCF}\) is what makes it possible to use a gradient-based generator, since with it the Lagrangian is differentiable,

\[ \mathcal{L}(A ; \lambda) = \text{cost}(A; \mathbf{x}_F) + \lambda \left(h(\mathbf{x}_{SCF}) - h(\mathbf{x}_F) \right) \; \; \text{,} \]

or, in simpler and more standard terms, since \(\lambda\) is constant,

\[ \mathcal{L_{\texttt{MINT}}}(\mathbf{x}_{SCF}) = \lambda \text{cost}(\mathbf{x}_{SCF}; \mathbf{x}_F) + \text{yloss}(\mathbf{x}_{SCF},y^*) \; \; \text{,} \]

where \(y^*\) is \(h(\mathbf{x}_F)\) and \(\text{yloss}\) is:

\[ \text{yloss}(\mathbf{x}_{SCF}, y^*) = h \left(\left\{ x_{F_i} + \delta_i [i \in I] + \left(f_i(\text{pa}_{SCF_i}) - f_i(\text{pa}_{F_i}) \right) [i \notin I] \right\}_{i=1}^n \right) - y^* \; \; \text{.} \]
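To make the closed-form expression for \(\mathbf{x}_{SCF}\) concrete, here is a minimal sketch on a toy three-variable chain. The structural equations, the coefficients, and the helper name `structural_counterfactual` are assumptions made up for illustration; they are not part of either package.

```julia
# Toy SCM over three variables in topological order: x1 -> x2 -> x3.
# f[i] takes the full feature vector but only reads the parents of variable i.
# All equations and coefficients are invented for this illustration.
f = [
    x -> 0.0,          # x1 has no parents
    x -> 2.0 * x[1],   # x2 := 2 * x1
    x -> x[2] + 1.0,   # x3 := x2 + 1
]

# Closed form from the text: intervened variables are shifted by δ_i,
# all others move by the change in their structural equation.
function structural_counterfactual(x_F, δ, I, f)
    x_SCF = copy(x_F)
    for i in eachindex(x_F)   # assumes indices follow a topological order
        if i in I
            x_SCF[i] = x_F[i] + δ[i]
        else
            # parents were already updated earlier in the loop
            x_SCF[i] = x_F[i] + f[i](x_SCF) - f[i](x_F)
        end
    end
    return x_SCF
end

x_F = [1.0, 2.5, 3.0]
δ = [0.5, 0.0, 0.0]
x_SCF = structural_counterfactual(x_F, δ, Set([1]), f)
```

Intervening only on the first variable shifts it by 0.5, and that change propagates to its descendants: the second variable moves by 2 × 0.5 and the third follows the second.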

Implementation

As mentioned above, this project involved contributions to both CausalInference.jl and CounterfactualExplanations.jl. In this section, we will cover both of these. Before we begin, we load all necessary dependencies below:

Code
using CausalInference
using CounterfactualExplanations
using CounterfactualExplanations.GenerativeModels
using Graphs
using GraphRecipes
using MultivariateStats
using Plots
using Random
Random.seed!(1)
using StatsBase

Causal Inference

In terms of implementation, we need to capture the causal relations from the data, which is where CausalInference.jl comes in. However, before the project, the package did not have an SCM structure: its methods only recovered the Directed Acyclic Graph (DAG) describing the causal structure that governs the data. There was previously no way to transform graphs into structural causal models.

Consider the following synthetic data:

Code
N = 2000 # number of data points

x = randn(N)
v = x + randn(N)*0.25
w = x + randn(N)*0.25
z = v + w + randn(N)*0.25
s = z + randn(N)*0.25
df = (x=x, v=v, w=w, z=z, s=s)

Using CausalInference.jl, we can use the ges method for causal discovery (Chickering 2003) and plot the resulting DAG (Figure 1):

Code
est_g, score = ges(df; penalty=1.0, parallel=true)

plt = graphplot(pdag2dag!(est_g), names= [String(k) for k in keys(df)], size=(500,500), nodesize=0.1, fontsize=25)
savefig(plt, "www/intro.png")
display(plt)
┌ Warning: Only one thread available
└ @ CausalInference ~/.julia/packages/CausalInference/ozcj8/src/ges.jl:52
Figure 1: A simple example of a causal graph.

Given the DAG in Figure 1, our goal is to recover the equations that define the underlying causal relations. The SCM is the union of the DAG and these causal equations: formally, it can be represented as a tuple \((G, \mathbf{f})\), where \(G\) is the DAG and \(\mathbf{f}\) is the set of functions that generates the value of each variable given its parents.

Our solution for constructing the structural causal equations was to assume that the data was generated by a linear model, which in this simple synthetic example actually corresponds to the ground truth. For the DAG provided in the code example we derive

\[ v = \mathcal{b}_v \]

\[ x = \mathcal{a}_{v \to x} v + \mathcal{b}_x \]

\[ w = \mathcal{a}_{x \to w} x + \mathcal{b}_w \]

\[ z = \mathcal{a}_{v \to z} v+ \mathcal{a}_{w \to z} w + \mathcal{b}_z \]

\[ s = \mathcal{a}_{z \to s} z + \mathcal{b}_s \]

This is the tricky part: as we can see, these causal equations differ from the ones that generated the data, but they are the ones that respect the causal structure of the estimated DAG. Here \(\mathcal{b}_i\) and \(\mathcal{a}_{i \to j}\) are the intercept term and the coefficient obtained from the linear regression, respectively. To fit the linear regressions in an order that respects the dependencies of the causal graph, we use topological_sort_by_dfs from Graphs.jl.
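The node-wise regressions described above can be sketched with plain least squares. This is only an illustration under stated assumptions: the parent sets are written out by hand to mirror the first three equations (rather than derived from the DAG with Graphs.jl), and `fit_node` is a hypothetical helper, not the function used in CausalInference.jl.

```julia
using LinearAlgebra, Random, Statistics

Random.seed!(1)
N = 2000
x = randn(N)
v = x + randn(N) * 0.25
w = x + randn(N) * 0.25
data = (x=x, v=v, w=w)

# Parent sets mirroring the equations above: v has no parents, v -> x, x -> w.
parents = (v=Symbol[], x=[:v], w=[:x])

# Hypothetical helper: regress one node on its parents plus an intercept.
function fit_node(data, node, pa)
    y = data[node]
    # One column per parent, followed by a column of ones for the intercept.
    X = hcat((data[p] for p in pa)..., ones(length(y)))
    β = X \ y   # ordinary least squares
    return (coefficients=β[1:(end - 1)], intercept=β[end], residuals=y - X * β)
end

fits = map(node -> fit_node(data, node, parents[node]), (v=:v, x=:x, w=:w))
```

Because `x` is regressed on `v` here, while the data were generated the other way round, the fitted slope is close to 1 / (1 + 0.25²) ≈ 0.94 rather than 1. For `v`, which has no parents, the fit reduces to an intercept equal to the sample mean.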

Now, with the SCM structure at hand, we see that the representation could be a struct containing the DAG and the coefficients/intercepts of the causal equations, which corresponds exactly to the tuple \((G, \mathbf{f})\) that we defined. A technical difficulty is that, since we aim for gradient-based counterfactual generation, we need to define a differentiable function that takes the SCM and applies the encoded causal relationships to all variables. That is where the causal_effects matrix comes to the rescue.

Let the factual vector of features be denoted as:

\[ \mathbf{x}_F = \begin{bmatrix} x_{F_1} \\ x_{F_2} \\ x_{F_3} \\ \vdots \\ x_{F_n} \end{bmatrix} \]

Let the causal_effects matrix be:

\[ \mathbf{C} = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} & b_1 \\ a_{21} & a_{22} & \cdots & a_{2n} & b_2 \\ a_{31} & a_{32} & \cdots & a_{3n} & b_3 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} & b_n \\ \end{bmatrix} \]

Here, \(a_{ij}\) represents the coefficient from the causal effect of \(x_{F_j}\) on \(x_{F_i}\), and \(b_i\) represents the intercept term for the variable \(x_{F_i}\).

The matrix multiplication of the causal_effects matrix with the factual vector (excluding the bias term) is given by:

\[ \mathbf{C}_{:, 1:n} \cdot \mathbf{x}_F = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ a_{31} & a_{32} & \cdots & a_{3n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{bmatrix} \begin{bmatrix} x_{F_1} \\ x_{F_2} \\ x_{F_3} \\ \vdots \\ x_{F_n} \end{bmatrix} \]

Finally, we add the bias term:

\[ \mathbf{x}_{SCF} = \mathbf{C}_{:, 1:n} \cdot \mathbf{x}_F + \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ \vdots \\ b_n \end{bmatrix} \]

In expanded form:

\[ x_{SCF_i} = a_{i1} x_{F_1} + a_{i2} x_{F_2} + \cdots + a_{in} x_{F_n} + b_i, \quad \forall i = 1, 2, \dots, n \]

This equation shows how each counterfactual variable \(x_{SCF_i}\) is generated as a linear combination of the factual inputs \(x_{F_j}\) based on the causal effects matrix, with an intercept term \(b_i\) added for each variable.
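As a small worked example of this matrix form, consider a three-variable chain. The matrix entries below are made-up numbers chosen for readability, not estimates from the data above.

```julia
# Rows and columns are ordered (v, x, w); the last column holds the intercepts.
# Encodes v = 0.1, x = 0.9 v, w = 1.0 x + 0.2 (illustrative values only).
C = [0.0 0.0 0.0 0.1;
     0.9 0.0 0.0 0.0;
     0.0 1.0 0.0 0.2]

x_F = [1.0, 2.0, 3.0]

A = C[:, 1:(end - 1)]   # coefficient block, C_{:, 1:n}
b = C[:, end]           # intercept column

x_SCF = A * x_F + b
```

The first row has no non-zero coefficients, so its output is just the intercept 0.1. The second entry is 0.9 × 1.0 = 0.9, and the third is 1.0 × 2.0 + 0.2 = 2.2.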

Note that orphan nodes, that is, nodes without parents in the DAG, are set equal to their intercept term \(\mathcal{b}_{\hat{o}}\). The intuition is that when we run the linear regression, a variable with no causal parents is simply fitted by its unconditional mean, i.e., we get \(x_{SCF_{\hat{o}}} = \mathbb{E}(x_{\hat{o}})\). Because of this, a closer look at the regression fit is needed in some cases, so the residuals are also part of the SCM structure,

Code
struct SCM
    variables::Vector{String}
    coefficients::Vector{Vector{Float64}}
    residuals::Vector{Vector{Float64}}
    dag::DiGraph
    causal_effects::Matrix{Float64}
end

CounterfactualExplanations.jl

Next, we turn to the optimization problem previously described in Section 2.2. Recall that we seek to minimize the Lagrangian defined there, which is now a differentiable function. The standard way to implement generators in CounterfactualExplanations.jl is to use automatic differentiation to minimize this Lagrangian. The definition of \(\mathcal{L_{\texttt{MINT}}}\) above is just an unconstrained objective function, much like with any other gradient-based generator in the package, so the optimization is straightforward (see docs for more details on gradient-based generators).

A challenge was to find a way to map \(x_F\) to \(x_{SCF}\). For the time being, we have decided to extend an existing feature of the package, namely InputTransformers: they can be used to transform features, for example, through standardization. All existing InputTransformers work under the premise of encoding features into some latent representation, searching for counterfactuals in that latent space, and finally decoding latent features back into the original feature space. In some way, this is also what we are doing here: we are mapping our factual to the “latent” causal space of the counterfactual.

Our first step is to create a new kind of InputTransformer for the SCM:

Code
const TypedInputTransformer = Union{
    Type{<:StatsBase.AbstractDataTransform},
    Type{<:MultivariateStats.AbstractDimensionalityReduction},
    Type{<:GenerativeModels.AbstractGenerativeModel},
    Type{<:CausalInference.SCM} # The SCM transformer
}

Next, we need a way to actually train this transformer. This is done by “overloading” the fit_transformer method,

Code
function fit_transformer(
    data::CounterfactualData, input_encoder::Type{<:CausalInference.SCM}; kwargs...
)
    t = Tables.table(transpose(data.X))
    est_g, score = CausalInference.ges(t; penalty=1.0, parallel=true)
    est_dag = CausalInference.pdag2dag!(est_g)
    scm = CausalInference.estimate_equations(t, est_dag)
    return scm
end

which takes an input dataset and then relies on CausalInference.jl for causal discovery.

We are getting there … but one implementation challenge is still left: how can we use the learned SCM during the counterfactual search?

Our idea was simple: during each gradient-step, just apply the SCM to all features of the counterfactual. Implementation-wise, this boiled down to overloading the decode_array function, which handles the actual decoding step for all InputTransformers:

Code
function decode_array(data::CounterfactualData, dt::CausalInference.SCM, x::AbstractArray)
    return run_causal_effects(dt, x)
end

function run_causal_effects(scm::CausalInference.SCM, x::AbstractArray)
    return scm.causal_effects[:, 1:(end - 1)] * x + scm.causal_effects[:, end] # bias
end

Here we are! Using this approach, gradient computations explicitly take the causal graph into account. We can now rely on standard workflows for gradient-based generators to solve a different minimization problem that incorporates causal effects.
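To see why gradients "take the causal graph into account", note that the decoding step is an affine map, so by the chain rule the gradient with respect to the search variable is the transposed coefficient block applied to the gradient with respect to the decoded features. The sketch below checks this by hand with finite differences instead of an autodiff library. The matrix, the linear score `w` standing in for a model, and the helper `finite_diff` are all assumptions for illustration.

```julia
using LinearAlgebra

# Illustrative causal_effects-style matrix for the chain v -> x -> w.
C = [0.0 0.0 0.0 0.1;
     0.9 0.0 0.0 0.0;
     0.0 1.0 0.0 0.2]
A, b = C[:, 1:(end - 1)], C[:, end]

decode(x) = A * x + b             # same affine form as run_causal_effects

w = [0.5, -1.0, 2.0]              # made-up linear score standing in for a model
loss(x) = dot(w, decode(x))

# Chain rule: the gradient w.r.t. x is Aᵀ times the gradient w.r.t. decode(x).
grad_analytic = A' * w

# Central finite differences as an independent check.
function finite_diff(f, x; ϵ=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = ϵ
        g[i] = (f(x + e) - f(x - e)) / (2ϵ)
    end
    return g
end

grad_numeric = finite_diff(loss, [1.0, 2.0, 3.0])
```

In this toy setup the gradient with respect to the last variable is zero, because it has no children and its decoded value is determined entirely by its parent. The first variable only influences the loss through its child.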

Concrete Generator Type

One piece that is still missing here is a concrete generator type for the MINT Generator (#466). That will make it easier for users to use the MINT Generator in the same way as all of our other counterfactual generators. This step has been postponed because it hinges on a larger development task (#435).

Limitations and Future Work

Although the number of lines of code was not tremendous, the hard work was. The merged code does not show all of the research and development that the mentors guided me through during this time. For example, we initially tried a different approach to differentiation inside the CounterfactualExplanations.jl package, where the code always broke 😅. Yet without this obstacle, we would not have identified a possible direction for future work: making the run_causal_effects function more flexible. For example, as we said, variables without causal parents are currently assigned their unconditional mean, but in terms of counterfactual theory, using the factual feature value might be more realistic. To do this, we would need to modify the causal_effects matrix. This was part of my work on creating transformable_features for the SCM, where we would probably need to use ignore_derivatives() from Zygote.jl.

Code
function transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
    g = counterfactual_data.input_encoder.dag
    child_causal_nodes = [v for v in vertices(g) if indegree(g, v) >= 1]
    return child_causal_nodes
end
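One way the more flexible behaviour discussed above might look is sketched below: nodes without parents keep their factual value instead of being replaced by the intercept. This is a hypothetical variant written for illustration only. The function name is invented, it detects root nodes from all-zero coefficient rows rather than from the DAG's in-degree (so a child whose fitted coefficients are all exactly zero would be misclassified), and its interaction with automatic differentiation is untested.

```julia
# Hypothetical variant of the decoding step, not part of either package.
function run_causal_effects_keep_roots(C::AbstractMatrix, x::AbstractVector)
    A, b = C[:, 1:(end - 1)], C[:, end]
    # A row with no non-zero coefficients is treated as a node without parents.
    is_root = vec(all(iszero, A; dims=2))
    decoded = A * x + b
    # Root nodes pass through unchanged; all other nodes use the SCM equations.
    return ifelse.(is_root, x, decoded)
end

C = [0.0 0.0 0.0 0.1;
     0.9 0.0 0.0 0.0;
     0.0 1.0 0.0 0.2]
out = run_causal_effects_keep_roots(C, [1.0, 2.0, 3.0])
```

With these illustrative numbers, the first entry stays at its factual value 1.0 instead of collapsing to the intercept 0.1, while the remaining entries are computed from their parents as before.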

Another direction for future work concerns the fact that the current implementation of the MINT Generator is limited to linear causal relations. One could extend this to non-linear causal relations, such as those modelled by neural networks. Additionally, we could shift the paradigm and use Bayesian inference to estimate the causal equations. This work would be a new extension of CausalInference.jl.

GitHub PRs and Issues

In the following links, you can find the PRs and issues that were opened and closed during the project. They give a history of the work:

Causal Inference:

CounterfactualExplanations:

And one that is still open:

Usage

The MINT algorithm can be run by combining the GenericGenerator with the SCM encoder, which we implemented using the CausalInference.jl package. The following code snippet shows how to generate counterfactuals with MINT; any other gradient-based generator can be used in place of the GenericGenerator:

Code
using CausalInference
using CounterfactualExplanations
using CounterfactualExplanations.DataPreprocessing: fit_transformer
using Tables

N = 2000
df = (
    x = randn(N), 
    v = randn(N) .^ 2 + randn(N) * 0.25, 
    w = cos.(randn(N)) + randn(N) * 0.25, 
    z = randn(N) .^ 2 + cos.(randn(N)) + randn(N) * 0.25 + randn(N) * 0.25, 
    s = sin.(randn(N) .^ 2 + cos.(randn(N)) + randn(N) * 0.25 + randn(N) * 0.25) + randn(N) * 0.25
)
y_lab = rand(0:2, N)
counterfactual_data_scm = CounterfactualData(Tables.matrix(df; transpose=true), y_lab)

M = fit_model(counterfactual_data_scm, :Linear)
chosen = rand(findall(predict_label(M, counterfactual_data_scm) .== 1))
x = select_factual(counterfactual_data_scm, chosen)

data_scm = deepcopy(counterfactual_data_scm)
data_scm.input_encoder = fit_transformer(data_scm, CausalInference.SCM)

ce = generate_counterfactual(x, 2, data_scm, M, GenericGenerator(); initialization=:identity)

For further usage details, see the official MINT documentation.

Conclusion

During this project, I had the opportunity to contribute to both the XAI and Julia communities by implementing a state-of-the-art method that uses causal information to generate counterfactual explanations (Karimi, Schölkopf, and Valera 2021). It was an amazing experience to work with incredible mentors and the community. I would like to once again thank Patrick and Moritz for all their guidance, as well as Jacob Zelko and JuliaHub for all their support. I learned a lot about the Julia language and the Julia community, and I witnessed firsthand the benefits of Open Source. In fact, it was Open Source that made it possible to contribute to two repositories simultaneously. This experience of working locally on CounterfactualExplanations.jl and CausalInference.jl was invaluable, as it allowed me to truly understand how things work under the hood.

I hope that the MINT Generator can be useful to the community and that the package continues to be improved in the future. Contributing to a more trustworthy Artificial Intelligence has always been a goal of mine, and even though my contribution may be small, I feel proud to have achieved something meaningful. My wish is to continue contributing to the responsible use of AI. I am very grateful for this opportunity, and I look forward to continuing to contribute to the community. 🚀🚀🚀

References

Altmeyer, Patrick, Mojtaba Farmanbar, Arie van Deursen, and Cynthia CS Liem. 2024. “Faithful Model Explanations Through Energy-Constrained Conformal Counterfactuals.” In Proceedings of the AAAI Conference on Artificial Intelligence, 38:10829–37. 10.
Chickering, David Maxwell. 2003. “Optimal Structure Identification with Greedy Search.” J. Mach. Learn. Res. 3: 507–54. https://doi.org/10.1162/153244303321897717.
Karimi, Amir-Hossein, Bernhard Schölkopf, and Isabel Valera. 2021. “Algorithmic Recourse: From Counterfactual Explanations to Interventions.” In Proceedings of the 2021 ACM Conference on Fairness, Accountability, and Transparency, 353–62. FAccT ’21. New York, NY, USA: Association for Computing Machinery. https://doi.org/10.1145/3442188.3445899.
Mothilal, Ramaravind K, Amit Sharma, and Chenhao Tan. 2020. “Explaining Machine Learning Classifiers Through Diverse Counterfactual Explanations.” In Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency, 607–17. https://doi.org/10.1145/3351095.3372850.
Wachter, Sandra, Brent Mittelstadt, and Chris Russell. 2017. “Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR.” Harv. JL & Tech. 31: 841. https://doi.org/10.2139/ssrn.3063289.

Citation

BibTeX citation:
@online{luiz_franco2024,
  author = {Luiz Franco, Jorge},
  title = {When {Causality} Meets {Recourse}},
  date = {2024-09-17},
  url = {https://www.taija.org/blog/posts/causal-recourse/},
  langid = {en}
}
For attribution, please cite this work as:
Luiz Franco, Jorge. 2024. “When Causality Meets Recourse.” September 17, 2024. https://www.taija.org/blog/posts/causal-recourse/.