When Causality meets Recourse
Counterfactual Explanations through Structural Causal Models
This post introduces a new tool in CounterfactualExplanations.jl, enhancing the package with causal reasoning to generate counterfactual explanations.
Introduction
In recent years, the need for interpretable and explainable AI has surged, particularly in high-stakes domains. Counterfactual explanations provide a means to understand how changes to input features could alter the outcomes of machine learning models. This blog post presents a new tool in the CounterfactualExplanations.jl package, developed during my JSoC (Julia Summer of Code) project, which incorporates causal reasoning into counterfactual generation.
Testimonial
This was an amazing experience, not only because I contributed to two repositories simultaneously, but also because I got to work with the maintainers of these repositories. I learned a lot about the Julia language and the Julia community. This was possible thanks to the mentorship of Patrick Altmeyer (CounterfactualExplanations) and Moritz Schauer (CausalInference), who guided me throughout the project and are amazing researchers.
Project Overview
This project aimed to enhance the CounterfactualExplanations.jl package by infusing it with a robust mathematical foundation for minimal algorithmic recourse, based on the principles of causal reasoning (Karimi, Schölkopf, and Valera 2021).
Key Contributions
During the project, I contributed to two key repositories:
CounterfactualExplanations.jl: Developed a new tool for generating counterfactual explanations using causal information. This allows users to generate counterfactuals through causal interventions rather than minimal perturbations, ultimately providing more meaningful insights.
CausalInference.jl: Implemented a Structural Causal Model (SCM) structure that extracts information from data, laying the groundwork for the causal reasoning capabilities in CounterfactualExplanations.jl.
Theoretical Background
In this project, we developed a framework for the MINT Generator: a counterfactual generator based on the Recourse through Minimal Intervention (MINT) method proposed by Karimi, Schölkopf, and Valera (2021).
The MINT Generator incorporates causal reasoning to achieve algorithmic recourse through minimal interventions. The main idea is that perturbing the inputs of a black-box model without taking the causal relations in the data into account can lead to misleading recommendations. Instead, we adopt a perspective in which every feature perturbation is an intervention in the causal graph of the problem. Thanks to the causal relationships, interventions on causal parents automatically lead to potentially useful changes in their causal children. The generator uses a Structural Causal Model (SCM) to encode the variables so that causal effects are propagated, and relies on a generic gradient-based generator to create the search path. This has the benefit that any existing gradient-based generator, such as ECCo (Altmeyer et al. 2024), Wachter (Wachter, Mittelstadt, and Russell 2017), DiCE (Mothilal, Sharma, and Tan 2020), and more, can be used with the MINT SCM encoder to generate counterfactuals through causal interventions.
The MINT algorithm minimizes a loss function that combines the causal constraints of the SCM and the distance between the generated counterfactual and the original input. Since we want to use a gradient-based generator, we need to turn the constrained optimization problem into an unconstrained one, which we do using the Lagrangian. Following Karimi, Schölkopf, and Valera (2021), we initially aim to find the minimal-cost set of actions \(A\) (in the form of structural interventions) that results in a counterfactual instance yielding the favorable output from \(h\),
\[ \begin{aligned} A^* \in \arg\min_A \text{cost}(A; \mathbf{x}_F)\\ \textrm{s.t.} \quad h(\mathbf{x}_{SCF}) \neq h(\mathbf{x}_F) \; \; \text{,}\\ \end{aligned} \]
where \(\mathbf{x}_F\) is the original input, \(\mathbf{x}_{SCF}\) is the counterfactual instance, and \(h\) is the black-box model. We use the \(\mathbf{x}_{SCF}\) terminology because the counterfactual is derived from the SCM,
\[ x_{SCF_i} = \begin{cases} x_{F_i} + \delta_i, & \text{if } i \in I \\ x_{F_i} + f_i(\text{pa}_{SCF_i}) - f_i(\text{pa}_{F_i}), & \text{if } i \notin I \; \; \text{,} \end{cases} \]
where \(I\) is the set of intervened-upon variables, \(f_i\) is the function that generates the value of variable \(i\) given its parents, and \(\text{pa}_{SCF_i}\) and \(\text{pa}_{F_i}\) are the values of the parents of variable \(i\) in the counterfactual and the original instance, respectively. This closed-form expression for the decision variable \(\mathbf{x}_{SCF}\) is what makes it possible to use a gradient-based generator, since it makes the Lagrangian differentiable,
\[ \mathcal{L}(A ; \lambda) = \text{cost}(A; \mathbf{x}_F) + \lambda \left(h(\mathbf{x}_{SCF}) - h(\mathbf{x}_F) \right) \; \; \text{,} \]
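or, since \(\lambda\) is a constant, we can divide by it and rename the resulting weight on the cost (again \(\lambda\)) to obtain the more standard form
\[ \mathcal{L}_{\texttt{MINT}}(\mathbf{x}_{SCF}) = \lambda \, \text{cost}(\mathbf{x}_{SCF}; \mathbf{x}_F) + \text{yloss}(\mathbf{x}_{SCF},y^*) \; \; \text{,} \]
where \(y^*\) is the target label, that is, an outcome different from \(h(\mathbf{x}_F)\), and \(\text{yloss}\) measures the discrepancy between the prediction for the counterfactual and this target (in practice, a classification loss such as cross-entropy). Schematically: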
\[ \text{yloss}(\mathbf{x}_{SCF}, y^*) = h \left(\left\{ x_{F_i} + \delta_i [i \in I] + \left(f_i(\text{pa}_{SCF_i}) - f_i(\text{pa}_{F_i}) \right) [i \notin I] \right\}_{i=1}^n \right) - y^* \; \; \text{.} \]
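To make the definitions above concrete, here is a minimal sketch for a toy linear SCM with three variables in a chain \(1 \to 2 \to 3\) and a logistic classifier. It is not the package's implementation: the helper names `structural_counterfactual` and `mint_loss`, the squared-norm cost on the intervention sizes \(\delta_i\), and the cross-entropy `yloss` are our own assumptions for illustration.

```julia
# Toy linear SCM, variables already in topological order 1 → 2 → 3.
# A[i, j] is the (assumed) effect of x_j on x_i, so f_i(pa_i) = Σ_j A[i, j] * x_j + intercept.
A = [0.0  0.0 0.0;
     0.8  0.0 0.0;
     0.0 -0.5 0.0]

# Hypothetical helper: structural counterfactual x_SCF for additive interventions δ on set I.
function structural_counterfactual(A, xF, δ, I)
    xSCF = copy(xF)
    for i in eachindex(xF)                 # topological order
        if i in I
            xSCF[i] = xF[i] + δ[i]         # intervened variable: shift by δ_i
        else
            # x_F_i + f_i(pa_SCF_i) - f_i(pa_F_i); for linear f_i the intercepts cancel
            xSCF[i] = xF[i] + sum(A[i, j] * (xSCF[j] - xF[j]) for j in eachindex(xF))
        end
    end
    return xSCF
end

# Assumed black-box model: logistic classifier h(x) = σ(wᵀx + c)
σ(t) = 1 / (1 + exp(-t))
h(x; w = [1.0, 1.0, -1.0], c = -1.0) = σ(w' * x + c)

# Hypothetical MINT-style objective: λ·cost + yloss, with cost = Σ_{i∈I} δ_i²
# and yloss = binary cross-entropy towards the target label y* ∈ {0, 1}
function mint_loss(A, xF, δ, I, ystar; λ = 0.1)
    p = h(structural_counterfactual(A, xF, δ, I))
    cost = sum(abs2, δ[i] for i in I)
    yloss = -(ystar * log(p) + (1 - ystar) * log(1 - p))
    return λ * cost + yloss
end

xF = [0.0, 0.0, 0.0]
δ = [1.0, 0.0, 0.0]
println(structural_counterfactual(A, xF, δ, [1]))   # [1.0, 0.8, -0.4]
println(mint_loss(A, xF, δ, [1], 1))
```

Intervening on variable 1 alone also shifts its descendants (by 0.8 and -0.4 here), which a plain feature perturbation would miss, while only the intervened variable contributes to the cost.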
Implementation
As mentioned above, this project involved contributions to both CausalInference.jl and CounterfactualExplanations.jl, and this section covers both. The snippets below assume that CausalInference.jl, CounterfactualExplanations.jl and Tables.jl are loaded (see the Usage section for the corresponding using statements).
Causal Inference
In terms of implementation, we first need to capture the causal relations in the data, which is where CausalInference.jl comes in. Before the project, however, the package had no SCM structure: its causal discovery methods only recovered the Directed Acyclic Graph (DAG) describing the causal structure governing the data, and there was no way to turn such a graph into a structural causal model.
Consider a synthetic dataset with five variables, x, v, w, z and s.
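The exact data-generating code is not reproduced in this post. As a stand-in, here is a minimal sketch that assumes a linear process in which x causes v and w, both of which cause z, which in turn causes s; the coefficients, noise scale and seed are our own assumptions.

```julia
using Random, Statistics

Random.seed!(42)   # for reproducibility
N = 2000
# Assumed linear structural equations with Gaussian noise
x = randn(N)
v = x + randn(N) * 0.25
w = x + randn(N) * 0.25
z = v + w + randn(N) * 0.25
s = z + randn(N) * 0.25

# A NamedTuple of equal-length vectors is a valid Tables.jl column table
df = (x = x, v = v, w = w, z = z, s = s)

# Variables connected by a causal path are strongly correlated
println(round(cor(df.x, df.v); digits = 2))
println(round(cor(df.z, df.s); digits = 2))
```

Keep in mind that GES identifies the graph only up to Markov equivalence, so the estimated DAG may orient some edges differently from the process that generated the data (for example, v → x instead of x → v).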
Using CausalInference.jl, we can apply the ges method for causal discovery (Chickering 2003) and plot the resulting DAG (Figure 1).
Given the DAG in Figure 1, our goal is to recover the equations that define the underlying causal relations. The SCM is the union of the DAG and these causal equations: formally, it can be represented as a tuple \((G, \mathbf{f})\), where \(G\) is the DAG and \(\mathbf{f}\) is the set of functions that generates the value of each variable given its parents.
Our solution for constructing the structural causal equations was to assume that the data was generated by a linear model, which in this simple synthetic example actually corresponds to the ground truth. For the DAG in Figure 1, we derive
\[ v = \mathcal{b}_v \]
\[ x = \mathcal{a}_{v \to x} v + \mathcal{b}_x \]
\[ w = \mathcal{a}_{x \to w} x + \mathcal{b}_w \]
\[ z = \mathcal{a}_{v \to z} v+ \mathcal{a}_{w \to z} w + \mathcal{b}_z \]
\[ s = \mathcal{a}_{z \to s} z + \mathcal{b}_s \]
The tricky part is that these causal equations differ from the ones that generated the data, but they are the ones consistent with the causal structure encoded in the estimated DAG. Here \(\mathcal{b}_i\) and \(\mathcal{a}_{i \to j}\) are the intercept term and the coefficient obtained from the linear regression, respectively. To solve the linear regressions in an order that respects the dependencies of the causal graph, we use topological_sort_by_dfs from Graphs.jl.
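The estimation step can be sketched with plain least squares. This is a minimal sketch, not the `estimate_equations` implementation in CausalInference.jl: the DAG is given as a hypothetical vector of parent lists, a simple Kahn-style topological sort stands in for `topological_sort_by_dfs` from Graphs.jl, and each variable is regressed on its parents plus an intercept. The data and coefficients are made up.

```julia
using Random, LinearAlgebra

# Kahn's algorithm: nodes of a DAG (given as parent lists) in topological order
function topological_order(parents::Vector{Vector{Int}})
    n = length(parents)
    indeg = length.(parents)
    children = [Int[] for _ in 1:n]
    for i in 1:n, p in parents[i]
        push!(children[p], i)
    end
    queue = [i for i in 1:n if indeg[i] == 0]
    order = Int[]
    while !isempty(queue)
        i = popfirst!(queue)
        push!(order, i)
        for c in children[i]
            indeg[c] -= 1
            indeg[c] == 0 && push!(queue, c)
        end
    end
    length(order) == n || error("graph has a cycle")
    return order
end

# Regress each variable (column of X) on its parents plus an intercept.
# Returns coefficients a[i] (aligned with parents[i]) and intercepts b[i].
function fit_linear_equations(X::Matrix{Float64}, parents::Vector{Vector{Int}})
    N, n = size(X)
    a = Vector{Vector{Float64}}(undef, n)
    b = zeros(n)
    for i in topological_order(parents)
        design = hcat(X[:, parents[i]], ones(N))   # parents first, intercept last
        β = design \ X[:, i]                        # least-squares solution
        a[i] = β[1:end-1]
        b[i] = β[end]
    end
    return a, b
end

# Variables ordered (v, x, w, z, s) with the DAG v → x → w, v → z ← w, z → s
parents = [Int[], [1], [2], [1, 3], [4]]
Random.seed!(1)
N = 2000
v = randn(N)
x = 0.9 .* v .+ 0.3 .* randn(N)
w = 1.2 .* x .+ 0.3 .* randn(N)
z = 0.5 .* v .+ 0.7 .* w .+ 0.3 .* randn(N)
s = 1.1 .* z .+ 0.3 .* randn(N)
a, b = fit_linear_equations(hcat(v, x, w, z, s), parents)
println(round.(a[4]; digits = 2))   # estimates of the effects of v and w on z
```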
Now, with the SCM structure at hand, a natural representation is a struct containing the DAG and the coefficients and intercepts of the causal equations, which corresponds exactly to the tuple \((G, \mathbf{f})\) defined above. A technical difficulty is that, since we aim for gradient-based counterfactual generation, we need a differentiable function that takes the SCM and applies the encoded causal relationships to all variables. This is where the causal_effects matrix comes to the rescue.
Let the factual vector of features be denoted as:
\[ \mathbf{x}_F = \begin{bmatrix} x_{F_1} \\ x_{F_2} \\ x_{F_3} \\ \vdots \\ x_{F_n} \end{bmatrix} \]
Let the causal_effects matrix be:
\[ \mathbf{C} = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} & b_1 \\ a_{21} & a_{22} & \cdots & a_{2n} & b_2 \\ a_{31} & a_{32} & \cdots & a_{3n} & b_3 \\ \vdots & \vdots & \ddots & \vdots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} & b_n \\ \end{bmatrix} \]
Here, \(a_{ij}\) represents the coefficient from the causal effect of \(x_{F_j}\) on \(x_{F_i}\), and \(b_i\) represents the intercept term for the variable \(x_{F_i}\).
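One possible representation, sketched under our own assumptions: a hypothetical `LinearSCMSketch` struct (not the `SCM` type in CausalInference.jl) stores the DAG as parent lists together with the coefficients and intercepts of each equation, and a hypothetical helper `effects_matrix_sketch` packs them into the \(n \times (n + 1)\) layout of \(\mathbf{C}\), with \(a_{ij}\) in row \(i\), column \(j\) and the intercepts in the last column.

```julia
# Hypothetical container for a linear SCM: the DAG (as parent lists) plus the
# coefficients a[i] (aligned with parents[i]) and intercepts b[i] of each equation
struct LinearSCMSketch
    parents::Vector{Vector{Int}}
    a::Vector{Vector{Float64}}
    b::Vector{Float64}
end

# Hypothetical helper: pack the equations into C = [A b], where C[i, j] is the effect of x_j on x_i
function effects_matrix_sketch(scm::LinearSCMSketch)
    n = length(scm.parents)
    C = zeros(n, n + 1)
    for i in 1:n
        for (k, j) in enumerate(scm.parents[i])
            C[i, j] = scm.a[i][k]
        end
        C[i, n + 1] = scm.b[i]
    end
    return C
end

# Chain v → x → w with made-up coefficients and intercepts
scm = LinearSCMSketch([Int[], [1], [2]], [Float64[], [0.9], [1.2]], [0.1, 0.0, -0.3])
C = effects_matrix_sketch(scm)
println(C)
```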
The matrix multiplication of the causal_effects matrix with the factual vector (excluding the bias term) is given by:
\[ \mathbf{C}_{:, 1:n} \cdot \mathbf{x}_F = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ a_{31} & a_{32} & \cdots & a_{3n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{bmatrix} \begin{bmatrix} x_{F_1} \\ x_{F_2} \\ x_{F_3} \\ \vdots \\ x_{F_n} \end{bmatrix} \]
Finally, we add the bias term:
\[ \mathbf{x}_{SCF} = \mathbf{C}_{:, 1:n} \cdot \mathbf{x}_F + \begin{bmatrix} b_1 \\ b_2 \\ b_3 \\ \vdots \\ b_n \end{bmatrix} \]
In expanded form:
\[ x_{SCF_i} = a_{i1} x_{F_1} + a_{i2} x_{F_2} + \cdots + a_{in} x_{F_n} + b_i, \quad \forall i = 1, 2, \dots, n \]
This equation shows how each counterfactual variable \(x_{SCF_i}\) is generated as a linear combination of the factual inputs \(x_{F_j}\) based on the causal effects matrix, with an intercept term \(b_i\) added for each variable.
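As a small numerical check of the formula above, here is a sketch with a hand-written causal_effects matrix for the chain v → x → w (coefficients made up; `apply_causal_effects` is a hypothetical name, not the package's `run_causal_effects`). It also shows a property of a single application: each variable is computed from the input values of its parents, so a change to v reaches x after one application but reaches w only when the map is applied again.

```julia
# Chain v → x → w: v is an orphan with intercept 0.1, x = 0.9 v, w = 1.2 x - 0.3
C = [0.0 0.0 0.0  0.1;
     0.9 0.0 0.0  0.0;
     0.0 1.2 0.0 -0.3]

# Hypothetical helper: x_SCF = C[:, 1:n] * x_F + b, intercepts stored in the last column
apply_causal_effects(C, xF) = C[:, 1:end-1] * xF .+ C[:, end]

xF = [1.0, 2.0, 3.0]
x1 = apply_causal_effects(C, xF)   # w still uses the factual x = 2.0
x2 = apply_causal_effects(C, x1)   # now w uses the updated x
println(x1)
println(x2)
```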
One can note that the counterfactual values of orphan nodes, that is, nodes without parents in the DAG, are equal to their intercept term \(\mathcal{b}_{\hat{o}}\). The intuition is that, in the linear regression, a variable with no causal parents is regressed on a constant only, so its fitted value is simply the unconditional mean of the variable, i.e., \(x_{SCF_{\hat{o}}} = \mathbb{E}(x_{\hat{o}})\). Because of this, a better understanding of the regression is sometimes needed, so the residuals are also stored as part of the SCM structure.
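The claim about orphan nodes follows from least squares: regressing a variable on a constant alone returns its sample mean. A minimal sketch with made-up data also shows why keeping the residuals is useful: intercept plus residual recovers each factual value.

```julia
using Random, Statistics, LinearAlgebra

Random.seed!(0)
v = 3.0 .+ randn(500)                 # an orphan node: no parents in the DAG

# Least-squares fit of v on an intercept only (design matrix = one column of ones)
design = ones(length(v), 1)
b_v = (design \ v)[1]
residuals = v .- b_v

println(b_v ≈ mean(v))                # true: the fitted intercept is the sample mean
println(b_v .+ residuals ≈ v)         # true: intercept + residual recovers the factual values
```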
CounterfactualExplanations.jl
Next, we turn to the optimization problem described in Section 2.2. Recall that we seek to minimize the Lagrangian defined there, which is now a differentiable function. The standard way to implement generators in CounterfactualExplanations.jl is to use automatic differentiation to minimize such an objective. The definition of \(\mathcal{L_{\texttt{MINT}}}\) above is just an unconstrained objective function, much like that of any other gradient-based generator in the package, so the optimization is straightforward (see the docs for more details on gradient-based generators).
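The idea of minimizing the objective through the SCM can be sketched without an autodiff library by writing the chain rule by hand. This is a minimal sketch under our own assumptions (a logistic model, a squared-distance cost on the decision variable, plain gradient descent, made-up coefficients, hypothetical function names); the package itself relies on automatic differentiation and its existing generators. The decision variable `zc` is mapped to feature space by the linear SCM map \(x = A z_c + b\), so the gradient of the prediction loss is pulled back through \(A^\top\).

```julia
σ(t) = 1 / (1 + exp(-t))

# Linear SCM map (causal_effects matrix split into A and intercepts b): x₂ depends on x₁
A = [0.0 0.0; 0.8 0.0]
b = [0.0, 0.1]
decode(zc) = A * zc .+ b

# Assumed logistic model h(x) = σ(wᵀx + c) and target label y* = 1
wm, c = [1.0, 2.0], -2.0
ystar = 1.0

# L(zc) = λ‖zc - x_F‖² + BCE(h(decode(zc)), y*)
function loss(zc, xF; λ = 0.1)
    p = σ(wm' * decode(zc) + c)
    return λ * sum(abs2, zc .- xF) - (ystar * log(p) + (1 - ystar) * log(1 - p))
end

# Chain rule: ∂BCE/∂x = (p - y*) w, pulled back through the SCM map as Aᵀ (p - y*) w
function grad(zc, xF; λ = 0.1)
    p = σ(wm' * decode(zc) + c)
    return 2λ .* (zc .- xF) .+ A' * ((p - ystar) .* wm)
end

# Plain gradient descent starting from the factual (identity initialization)
function search(xF; η = 0.5, steps = 200)
    zc = copy(xF)
    for _ in 1:steps
        zc .-= η .* grad(zc, xF)
    end
    return zc, decode(zc)
end

xF = [0.0, 0.0]
zc, xcf = search(xF)
println(xcf)
println(σ(wm' * xcf + c))   # predicted probability of the target class
```

In this sketch the orphan feature x₁ is pinned to its intercept after decoding, so the search can only change the prediction through x₁'s effect on x₂; this mirrors the orphan-node behavior of the causal_effects matrix described above.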
A remaining challenge was to find a way to map \(x_F\) to \(x_{SCF}\). For the time being, we have decided to extend an existing feature of the package, namely InputTransformers, which can be used to transform features, for example through standardization. All existing InputTransformers work under the premise of encoding features into some latent representation, searching for counterfactuals in that latent space, and finally decoding the latent features back into the original feature space. In a sense, this is also what we are doing here: we pass our factual into the “latent” causal space of the counterfactual.
Our first step is to create a new kind of InputTransformer for the SCM.
Next, we need a way to actually train this transformer. This is done by overloading the fit_transformer method:
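function fit_transformer(
    data::CounterfactualData, input_encoder::Type{<:CausalInference.SCM}; kwargs...
)
    # data.X stores one sample per column, so transpose to an (observations × features) table
    t = Tables.table(transpose(data.X))
    # Causal discovery with Greedy Equivalence Search; keep the estimated graph and its score
    est_g, score = CausalInference.ges(t; penalty=1.0, parallel=true)
    # Orient the remaining undirected edges to obtain a DAG
    est_dag = CausalInference.pdag2dag!(est_g)
    # Fit the linear structural equations consistent with the DAG
    scm = CausalInference.estimate_equations(t, est_dag)
    return scm
end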
which takes an input dataset and relies on CausalInference.jl both for causal discovery and for estimating the structural equations.
We are getting there … but one implementation challenge is still left: how can we use the learned SCM during the counterfactual search?
Our idea was simple: during each gradient step, apply the SCM to all features of the counterfactual. Implementation-wise, this boiled down to overloading the decode_array function, which handles the actual decoding step for all InputTransformers.
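To illustrate the decoding step with Julia's multiple dispatch, here is a minimal sketch. The types `IdentityEncoder` and `LinearSCMEncoder` and the function `decode_sketch` are hypothetical stand-ins, not the package's `InputTransformer` types or its `decode_array` signature. Counterfactuals are stored one per column (features × samples), matching the layout of `data.X` that `fit_transformer` transposes above.

```julia
abstract type EncoderSketch end

struct IdentityEncoder <: EncoderSketch end

struct LinearSCMEncoder <: EncoderSketch
    C::Matrix{Float64}    # causal_effects matrix [A b]
end

# Decoding dispatches on the encoder type; X holds one counterfactual per column
decode_sketch(::IdentityEncoder, X::AbstractMatrix) = X
decode_sketch(t::LinearSCMEncoder, X::AbstractMatrix) = t.C[:, 1:end-1] * X .+ t.C[:, end]

# Made-up SCM: x₁ is an orphan with intercept 0.5, x₂ = 2 x₁ + 1
enc = LinearSCMEncoder([0.0 0.0 0.5;
                        2.0 0.0 1.0])
Z = [1.0 3.0;
     0.0 0.0]             # two candidate counterfactuals (columns)
println(decode_sketch(enc, Z))
```

With this design, supporting a new kind of transformer only requires a new decoding method, which is the spirit of overloading decode_array.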
Here we are! With this approach, gradient computations explicitly take the causal graph into account. We can now rely on the standard workflow for gradient-based generators to solve a different minimization problem, one that incorporates causal effects.
One piece that is still missing is a concrete generator type for the MINT Generator (#466). This will make it easier for users to use the MINT Generator in the same way as all of our other counterfactual generators. This step has been postponed because it hinges on a larger development task (#435).
Limitations and Future Work
Although the number of lines of code was not tremendous, the work behind them was hard. The merged code does not show all the research and development that my mentors guided me through during this time. For example, we initially tried a different approach to handling differentiation inside the CounterfactualExplanations.jl package, and the code always broke 😅. But without this obstacle, we would not have identified a possible direction for future work: making the run_causal_effects function more flexible. For example, as mentioned above, variables without causal parents are currently assigned the unconditional mean, but from the point of view of counterfactual theory, using the factual feature value might be more realistic. To do this, we would need to work on the causal_effects_matrix, which was part of my work on creating transformable_features for the SCM, and we would probably need to use ignore_derivatives() from Zygote.jl.
Another direction for future work is that the current implementation of the MINT Generator is limited to linear causal relations. One could extend it to non-linear causal relations, for example by modeling the structural equations with neural networks. Additionally, we could shift the paradigm and use Bayesian inference to estimate the causal equations. This work would be a new extension of CausalInference.jl.
GitHub PRs and Issues
The following links point to the PRs and issues that were opened and closed during the project. Together, they document the history of the work:
Causal Inference:
- PR1 - SCM
- PR2 - causal effects matrix
- PR3 - version Julia register
- Issue1 - Retrieve equations CausalGraph
CounterfactualExplanations:
- PR1 - encondings.jl
- PR2 - Constrained Optimization
- PR3 - add MINT docs
- Issue1 - support for SCM
- Issue2 - Constrained Optimization
- Issue3 - Document MINT
And one that is still open:
Usage
The MINT algorithm can be implemented using the GenericGenerator and the SCM encoder that we implemented with the CausalInference.jl package. The following code snippet shows how to generate counterfactuals with the MINT algorithm; the GenericGenerator can be replaced by any other gradient-based generator:
using CausalInference
using CounterfactualExplanations
using CounterfactualExplanations.DataPreprocessing: fit_transformer
using Tables

# Synthetic data with causal dependencies x → v, x → w, (v, w) → z, z → s
N = 2000
x = randn(N)
v = x .^ 2 + randn(N) * 0.25
w = cos.(x) + randn(N) * 0.25
z = v + w + randn(N) * 0.25
s = sin.(z) + randn(N) * 0.25
df = (x = x, v = v, w = w, z = z, s = s)
# Random labels for illustration (three classes)
y_lab = rand(0:2, N)
# CounterfactualData expects one sample per column, hence transpose=true
counterfactual_data_scm = CounterfactualData(Tables.matrix(df; transpose=true), y_lab)
M = fit_model(counterfactual_data_scm, :Linear)
# Pick a random factual that the model assigns to class 1
chosen = rand(findall(predict_label(M, counterfactual_data_scm) .== 1))
factual = select_factual(counterfactual_data_scm, chosen)
# Attach the fitted SCM as the input encoder of a copy of the data
data_scm = deepcopy(counterfactual_data_scm)
data_scm.input_encoder = fit_transformer(data_scm, CausalInference.SCM)
# Search for a counterfactual with target class 2 (any gradient-based generator works)
ce = generate_counterfactual(factual, 2, data_scm, M, GenericGenerator(); initialization=:identity)
For further usage details, see the official MINT documentation.
Conclusion
During this project, I had the opportunity to contribute to both the XAI and Julia communities by implementing a state-of-the-art method that uses causal information to generate counterfactual explanations (Karimi, Schölkopf, and Valera 2021). It was an amazing experience to work with incredible mentors and the community. I would like to once again thank Patrick and Moritz for all their guidance, as well as Jacob Zelko and JuliaHub for all their support. I learned a lot about the Julia language and the Julia community, and I witnessed firsthand the benefits of open source. In fact, it was open source that made it possible to contribute to two repositories simultaneously. Working locally on both CounterfactualExplanations.jl and CausalInference.jl was invaluable, as it allowed me to truly understand how things work under the hood.
I hope that the MINT Generator can be useful to the community and that the package continues to be improved in the future. Contributing to a more trustworthy Artificial Intelligence has always been a goal of mine, and even though my contribution may be small, I feel proud to have achieved something meaningful. My wish is to continue contributing to the responsible use of AI. I am very grateful for this opportunity, and I look forward to continuing to contribute to the community. 🚀🚀🚀
References
Citation
@online{luiz_franco2024,
author = {Luiz Franco, Jorge},
title = {When {Causality} Meets {Recourse}},
date = {2024-09-17},
url = {https://www.taija.org/blog/posts/causal-recourse/},
langid = {en}
}