Reference
In this reference, you will find a detailed overview of the package API.
Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented.
(Diátaxis)
In other words, you come here because you want to take a very close look at the code 🧐.
Content
Exported functions
CounterfactualExplanations.RawOutputArrayType (Type)
RawOutputArrayType
A type union for the allowed types for the output array y.
CounterfactualExplanations.RawTargetType (Type)
RawTargetType
A type union for the allowed types for the target variable.
CounterfactualExplanations.flux_training_params (Constant)
flux_training_params
The default training parameters for FluxModels etc.
CounterfactualExplanations.AbstractConvergence (Type)
An abstract type that serves as the base type for convergence objects.
CounterfactualExplanations.AbstractCounterfactualExplanation (Type)
Base type for counterfactual explanations.
CounterfactualExplanations.AbstractGenerator (Type)
An abstract type that serves as the base type for counterfactual generators.
CounterfactualExplanations.AbstractModel (Type)
Base type for models.
CounterfactualExplanations.AbstractPenalty (Type)
An abstract type for penalty functions.
CounterfactualExplanations.CounterfactualExplanation (Type)
A struct that collects all information relevant to a specific counterfactual explanation for a single individual.
CounterfactualExplanations.CounterfactualExplanation (Method)
function CounterfactualExplanation(;
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.AbstractGenerator,
    num_counterfactuals::Int = 1,
    initialization::Symbol = :add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)
Outer method to construct a CounterfactualExplanation structure.
CounterfactualExplanations.EncodedOutputArrayType (Type)
EncodedOutputArrayType
Type of encoded output array.
CounterfactualExplanations.EncodedTargetType (Type)
EncodedTargetType
Type of encoded target variable.
CounterfactualExplanations.OutputEncoder (Type)
OutputEncoder
The OutputEncoder takes a raw output array (y) and encodes it.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)(ynew::RawTargetType)
When called on a new value ynew, the OutputEncoder encodes it based on the initial encoding.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)()
On call, the OutputEncoder returns the encoded output array.
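As a rough illustration of what encoding a raw output array can look like, the following plain-Julia sketch one-hot encodes labels against their sorted levels and then encodes a new value against those same levels, mirroring the two call forms of OutputEncoder. The function onehot_encode is hypothetical; the package's actual encoding (for instance, how binary targets are handled) may differ.

```julia
# Hypothetical sketch (not the package's implementation): one-hot encode raw labels
# against their sorted levels, one row per level and one column per sample.
function onehot_encode(y::AbstractVector)
    y_levels = sort(unique(y))
    Y = zeros(Int, length(y_levels), length(y))
    for (j, label) in enumerate(y)
        Y[findfirst(==(label), y_levels), j] = 1
    end
    return Y, y_levels
end

# Encode a new value against the levels stored from the initial encoding.
# An unseen value makes findfirst return nothing, so the indexing below throws.
function onehot_encode(ynew, y_levels::AbstractVector)
    v = zeros(Int, length(y_levels))
    v[findfirst(==(ynew), y_levels)] = 1
    return v
end

Y, levels = onehot_encode(["cat", "dog", "cat", "bird"])
onehot_encode("dog", levels)  # [0, 0, 1]
```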
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Base.Iterators.Zip,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)
Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
    timeout::Union{Nothing,Real}=nothing,
)
The core function that is used to run counterfactual search for a given factual x, target, counterfactual data, model and generator. Keyword arguments set the number of counterfactuals, the initialization method, the convergence criterion (which determines, for example, the threshold for the predicted target class probability and the maximum number of iterations) and an optional timeout.
Arguments
- x::Matrix: Factual data point.
- target::RawTargetType: Target class.
- data::CounterfactualData: Counterfactual data.
- M::Models.AbstractModel: Fitted model.
- generator::AbstractGenerator: Generator.
- num_counterfactuals::Int=1: Number of counterfactuals to generate for the factual.
- initialization::Symbol=:add_perturbation: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
- convergence::Union{AbstractConvergence,Symbol}=:decision_threshold: Convergence criterion. By default, convergence is based on the decision threshold. Possible values are :decision_threshold, :max_iter, :generator_conditions or a concrete convergence object (e.g. DecisionThresholdConvergence).
- timeout::Union{Nothing,Real}=nothing: Timeout in seconds.
Examples
Generic generator
julia> using CounterfactualExplanations
julia> using TaijaData
# Counterfactual data and model:
julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(counterfactual_data, :Linear);
julia> target = 2;
julia> factual = 1;
julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual));
julia> x = select_factual(counterfactual_data, chosen);
# Search:
julia> generator = Generators.GenericGenerator();
julia> ce = generate_counterfactual(x, target, counterfactual_data, M, generator);
julia> converged(ce.convergence, ce)
true
Broadcasting
The generate_counterfactual method can also be broadcast over a tuple containing an array. This allows for generating counterfactuals for multiple factuals in a single call.
julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 5);
julia> xs = select_factual(counterfactual_data, chosen);
julia> ces = generate_counterfactual.(xs, target, counterfactual_data, M, generator);
julia> all(converged.(ces))
true
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::DataPreprocessing.CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.GrowingSpheresGenerator;
    num_counterfactuals::Int=1,
    convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
        decision_threshold=(1 / length(data.y_levels)), max_iter=1000
    ),
    kwrgs...,
)
Overloads the generate_counterfactual method for the GrowingSpheresGenerator generator.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(x::Tuple{<:AbstractArray}, args...; kwargs...)
Overloads the generate_counterfactual method to accept a tuple containing an array. This allows for broadcasting over Zip iterators.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
    x::Vector{<:Matrix},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)
Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.get_target_index (Method)
get_target_index(y_levels, target)
Utility that returns the index of target in y_levels.
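The lookup performed by get_target_index can be sketched in plain Julia with findfirst. The helper target_index below is a hypothetical stand-in, not the package's implementation; in particular, raising an ArgumentError for an unknown target is an assumption made for this sketch.

```julia
# Hypothetical stand-in for get_target_index: position of `target` in `y_levels`.
function target_index(y_levels::AbstractVector, target)
    idx = findfirst(==(target), y_levels)
    # Assumption for this sketch: an unknown target is an error.
    idx === nothing && throw(ArgumentError("target $(repr(target)) not found in y_levels"))
    return idx
end

target_index([1, 2, 3], 2)          # 2
target_index(["no", "yes"], "yes")  # 2
```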
CounterfactualExplanations.path (Method)
path(ce::CounterfactualExplanation)
A convenience method that returns the entire counterfactual path.
CounterfactualExplanations.target_probs (Function)
target_probs(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.
CounterfactualExplanations.terminated (Method)
terminated(ce::CounterfactualExplanation)
A convenience method that checks if the counterfactual search has terminated.
CounterfactualExplanations.total_steps (Method)
total_steps(ce::CounterfactualExplanation)
A convenience method that returns the total number of steps of the counterfactual search.
CounterfactualExplanations.Convergence.convergence_catalogue (Constant)
convergence_catalogue
A dictionary containing all convergence criteria.
CounterfactualExplanations.Convergence.DecisionThresholdConvergence (Type)
DecisionThresholdConvergence
Convergence criterion based on the target class probability threshold. The search stops when the target class probability exceeds the predefined threshold.
Fields
- decision_threshold::AbstractFloat: The predefined threshold for the target class probability.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the target class probability.
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Type)
GeneratorConditionsConvergence
Convergence criterion for counterfactual explanations based on the generator conditions. The search stops when the gradients of the search objective are below a certain threshold and the generator conditions are satisfied.
Fields
- decision_threshold::AbstractFloat: The threshold for the decision probability.
- gradient_tol::AbstractFloat: The tolerance for the gradients of the search objective.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the generator conditions (across counterfactuals).
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Method)
GeneratorConditionsConvergence(; decision_threshold=0.5, gradient_tol=1e-2, max_iter=100, min_success_rate=0.75, y_levels=nothing)
Outer constructor for GeneratorConditionsConvergence.
CounterfactualExplanations.Convergence.MaxIterConvergence (Type)
MaxIterConvergence
Convergence criterion based on the maximum number of iterations.
Fields
- max_iter::Int: The maximum number of iterations.
CounterfactualExplanations.Convergence.conditions_satisfied (Method)
conditions_satisfied(gen::AbstractGenerator, ce::AbstractCounterfactualExplanation)
This function is overloaded in the Generators module to check whether the counterfactual search has converged with respect to generator conditions.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::InvalidationRateConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the invalidation rate.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::DecisionThresholdConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::MaxIterConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the maximum number of iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached, independently of the other convergence criteria.
CounterfactualExplanations.Convergence.converged (Function)
converged(
    convergence::GeneratorConditionsConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.
CounterfactualExplanations.Convergence.converged (Method)
converged(ce::AbstractCounterfactualExplanation)
Returns true if the counterfactual explanation has converged.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::AbstractConvergence)
Returns the convergence object.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::Symbol)
Returns the convergence object from the dictionary of default convergence types.
CounterfactualExplanations.Convergence.invalidation_rate (Method)
invalidation_rate(ce::AbstractCounterfactualExplanation)
Calculates the invalidation rate of a counterfactual explanation.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
- kwargs: Additional keyword arguments to pass to the function.
Returns
The invalidation rate of the counterfactual explanation.
CounterfactualExplanations.Convergence.threshold_reached (Function)
threshold_reached(ce::AbstractCounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Determines if the predefined threshold for the target class probability has been reached.
CounterfactualExplanations.Evaluation.default_measures (Constant)
The default evaluation measures.
CounterfactualExplanations.Evaluation.Benchmark (Type)
A container for benchmarks of counterfactual explanations. Instead of subtyping DataFrame, it contains a DataFrame of evaluation measures (see this discussion for why we don't subtype DataFrame directly).
CounterfactualExplanations.Evaluation.Benchmark (Method)
(bmk::Benchmark)(; agg=mean)
Returns a DataFrame containing evaluation measures aggregated by num_counterfactual.
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    data::CounterfactualData;
    models::Dict{<:Any,<:Any}=standard_models_catalogue,
    generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    n_individuals::Int=5,
    suppress_training::Bool=false,
    factual::Union{Nothing,RawTargetType}=nothing,
    target::Union{Nothing,RawTargetType}=nothing,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)
Runs the benchmarking exercise as follows:
- Randomly choose a factual and target label unless specified.
- If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
- Each of these models is then trained on the data.
- For each model separately, choose n_individuals randomly from the non-target (factual) class. For each generator, create a benchmark as in benchmark(xs::Union{AbstractArray,Base.Iterators.Zip}).
- Finally, concatenate the results.
If vertical_splits is set to an integer, the computations are split vertically into vertical_splits chunks. In this case, the results are stored in a temporary directory and concatenated afterwards.
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    x::Union{AbstractArray,Base.Iterators.Zip},
    target::RawTargetType,
    data::CounterfactualData;
    models::Dict{<:Any,<:AbstractModel},
    generators::Dict{<:Any,<:AbstractGenerator},
    measure::Union{Function,Vector{Function}}=default_measures,
    xids::Union{Nothing,AbstractArray}=nothing,
    dataname::Union{Nothing,Symbol,String}=nothing,
    verbose::Bool=true,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)
First generates counterfactual explanations for factual x, the target and data using each of the provided models and generators. Then generates a Benchmark for the vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
    counterfactual_explanations::Vector{CounterfactualExplanation};
    meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    store_ce::Bool=false,
)
Generates a Benchmark for a vector of counterfactual explanations. Optionally, meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it will be automatically inferred. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.
CounterfactualExplanations.Evaluation.evaluate (Function)
evaluate(
    ce::CounterfactualExplanation;
    measure::Union{Function,Vector{Function}}=default_measures,
    agg::Function=mean,
    report_each::Bool=false,
    output_format::Symbol=:Vector,
    pivot_longer::Bool=true
)
Computes evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is automatically inferred, unless this is overwritten by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.
CounterfactualExplanations.Evaluation.faithfulness (Method)
faithfulness(
    ce::CounterfactualExplanation,
    fun::typeof(Objectives.distance_from_target);
    λ::AbstractFloat=1.0,
    kwrgs...,
)
Computes the faithfulness of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the model posterior through SGLD (see distance_from_posterior).
CounterfactualExplanations.Evaluation.plausibility (Method)
plausibility(
    ce::CounterfactualExplanation,
    fun::typeof(Objectives.distance_from_target);
    K=nothing,
    kwrgs...,
)
Computes the plausibility of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.redundancy (Method)
redundancy(ce::CounterfactualExplanation)
Computes the feature redundancy: that is, the number of features that remain unchanged from their original, factual values.
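A minimal sketch of this count, assuming the factual and the counterfactual are plain vectors of equal length. The function redundancy_sketch and its atol keyword (a tolerance below which a change counts as none) are hypothetical and not part of the package.

```julia
# Hypothetical sketch: number of features whose value is (approximately) unchanged.
# atol: assumed absolute tolerance; with atol=0.0, isapprox falls back to a tiny relative tolerance.
function redundancy_sketch(x::AbstractVector{<:Real}, x_cf::AbstractVector{<:Real}; atol=0.0)
    return count(i -> isapprox(x[i], x_cf[i]; atol=atol), eachindex(x, x_cf))
end

x    = [1.0, 2.0, 3.0]
x_cf = [1.0, 2.5, 3.0]
redundancy_sketch(x, x_cf)            # 2 (only the second feature changed)
redundancy_sketch(x, x_cf; atol=0.6)  # 3 (a change of 0.5 is within tolerance)
```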
CounterfactualExplanations.Evaluation.validity (Method)
validity(ce::CounterfactualExplanation; γ=0.5)
Checks if the counterfactual search has been successful with respect to the probability threshold γ. In case multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
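The proportion returned for multiple counterfactuals can be sketched as follows, assuming the target-class probabilities of the counterfactuals are already available as a vector. validity_sketch is hypothetical, and counting a probability exactly equal to γ as a success is an assumption of this sketch.

```julia
# Hypothetical sketch: share of counterfactuals whose target-class probability reaches γ.
function validity_sketch(target_probs::AbstractVector{<:Real}; γ=0.5)
    return count(p -> p >= γ, target_probs) / length(target_probs)
end

validity_sketch([0.7, 0.3, 0.55, 0.5])         # 0.75
validity_sketch([0.7, 0.3, 0.55, 0.5]; γ=0.6)  # 0.25
```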
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
CounterfactualData(
    X::AbstractMatrix,
    y::RawOutputArrayType;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)
This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.
Examples
using CounterfactualExplanations
using TaijaData
# Load a toy data set: a feature matrix X and labels y.
X, y = load_linearly_separable()
counterfactual_data = CounterfactualData(X, y)
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
function CounterfactualData(
    X::Tables.MatrixTable,
    y::RawOutputArrayType;
    kwrgs...
)
Outer constructor method that accepts a Tables.MatrixTable. By default, the indices of categorical and continuous features are automatically inferred from the features' scitype.
CounterfactualExplanations.DataPreprocessing.apply_domain_constraints (Method)
apply_domain_constraints(counterfactual_data::CounterfactualData, x::AbstractArray)
A subroutine that is used to apply the predetermined domain constraints.
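The docstring does not describe how the constraints are applied. As one possible illustration, assuming the domain is given as one (lower, upper) tuple per feature, applying it amounts to clamping each feature into its interval. clamp_to_domain is a hypothetical helper, not the package's implementation.

```julia
# Hypothetical sketch: clamp each feature into its (lower, upper) interval.
# Assumption: `domain` holds one (lower, upper) tuple per feature, in feature order.
function clamp_to_domain(x::AbstractVector{<:Real}, domain::AbstractVector{<:Tuple{Real,Real}})
    length(x) == length(domain) || throw(DimensionMismatch("need one interval per feature"))
    return [clamp(xi, lo, hi) for (xi, (lo, hi)) in zip(x, domain)]
end

clamp_to_domain([-1.0, 0.5, 7.0], [(0.0, 1.0), (0.0, 1.0), (0.0, 5.0)])  # [0.0, 0.5, 5.0]
```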
CounterfactualExplanations.DataPreprocessing.fit_transformer! (Method)
fit_transformer!(
    data::CounterfactualData,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer};
    kwargs...,
)
Fit a transformer to the data in place.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(data::CounterfactualData, input_encoder::Nothing; kwargs...)
Fit a transformer to the data. This is a no-op if input_encoder is Nothing.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
    data::CounterfactualData,
    input_encoder::Type{<:CausalInference.SCM};
    kwargs...,
)
Fit a transformer to the data for a SCM object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
    data::CounterfactualData,
    input_encoder::Type{GenerativeModels.AbstractGenerativeModel};
    kwargs...,
)
Fit a transformer to the data for a GenerativeModels.AbstractGenerativeModel object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
    data::CounterfactualData,
    input_encoder::Type{MultivariateStats.AbstractDimensionalityReduction};
    kwargs...,
)
Fit a transformer to the data for a MultivariateStats.AbstractDimensionalityReduction object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
    data::CounterfactualData,
    input_encoder::Type{StatsBase.AbstractDataTransform};
    kwargs...,
)
Fit a transformer to the data for a StatsBase.AbstractDataTransform object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(data::CounterfactualData, input_encoder::InputTransformer; kwargs...)
Fit a transformer to the data for an InputTransformer object. This is a no-op.
CounterfactualExplanations.DataPreprocessing.select_factual (Method)
select_factual(counterfactual_data::CounterfactualData, index::Int)
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.select_factual (Method)
select_factual(counterfactual_data::CounterfactualData, index::Union{Vector{Int},UnitRange{Int}})
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
By default, all continuous features are transformable. This function returns the indices of all continuous features.
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
Returns the indices of all features that have causal parents.
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(
    counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
Returns the indices of all continuous features that can be transformed. Constant features are excluded, because ZScoreTransform returns NaN for them (their standard deviation is zero).
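Why constant features are a problem follows from the z-score formula (x - mean) / std: a constant feature has a standard deviation of zero, so every entry becomes 0.0 / 0.0, which is NaN. The plain-Julia sketch below shows this with the Statistics standard library and filters out constant features, assuming features are stored in rows with one column per sample. zscore_feature and nonconstant_features are hypothetical helpers, not StatsBase or package functions.

```julia
using Statistics

# z-score standardisation of one feature: subtract the mean, divide by the (sample) std.
zscore_feature(v::AbstractVector{<:Real}) = (v .- mean(v)) ./ std(v)

# Indices of features (rows) that can be standardised without producing NaN.
# Assumption: X stores one feature per row and one sample per column.
nonconstant_features(X::AbstractMatrix{<:Real}) = [i for i in axes(X, 1) if std(X[i, :]) > 0]

zscore_feature([1.0, 2.0, 3.0])                    # [-1.0, 0.0, 1.0]
zscore_feature([4.0, 4.0, 4.0])                    # [NaN, NaN, NaN]
nonconstant_features([1.0 2.0 3.0; 5.0 5.0 5.0])   # [1]
```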
CounterfactualExplanations.DataPreprocessing.transformable_features (Method)
transformable_features(counterfactual_data::CounterfactualData)
Dispatches the transformable_features function to the appropriate method based on the type of the dt field.
CounterfactualExplanations.Models.all_models_catalogue (Constant)
all_models_catalogue
A dictionary containing both differentiable and non-differentiable machine learning models.
CounterfactualExplanations.Models.standard_models_catalogue (Constant)
standard_models_catalogue
A dictionary containing all differentiable machine learning models.
CounterfactualExplanations.AbstractModel (Method)
(model::AbstractModel)(X::AbstractArray)
When called on data X, logits are returned.
CounterfactualExplanations.Models.DeepEnsemble (Method)
DeepEnsemble(model; likelihood::Symbol=:classification_binary)
An outer constructor for a deep ensemble model.
CounterfactualExplanations.Models.Linear (Method)
Linear(model; likelihood::Symbol=:classification_binary)
An outer constructor for a linear model.
CounterfactualExplanations.Models.MLP (Method)
MLP(model; likelihood::Symbol=:classification_binary)
An outer constructor for a multi-layer perceptron (MLP) model.
CounterfactualExplanations.Models.Model (Type)
Model <: AbstractModel
Constructor for all models.
CounterfactualExplanations.Models.Model (Method)
Model(model, type::AbstractFluxNN; likelihood::Symbol=:classification_binary)
Overloaded constructor for Flux models.
CounterfactualExplanations.Models.Model (Method)
Model(model, type::AbstractModelType; likelihood::Symbol=:classification_binary)
Outer constructor for Model where the atomic model is defined and assumed to be pre-trained.
CounterfactualExplanations.Models.Model (Method)
(M::Model)(data::CounterfactualData, type::DeepEnsemble; kwargs...)
Constructs a deep ensemble for the given data.
CounterfactualExplanations.Models.Model (Method)
(M::Model)(data::CounterfactualData, type::Linear; kwargs...)
Constructs a model with one linear layer for the given data. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
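The link between a single linear layer and (multinomial) logistic regression can be seen by mapping logits to probabilities by hand. The sketch below uses hypothetical weights and plain-Julia definitions of sigmoid and softmax (not the package's or Flux's functions). It also checks that, for two classes, the softmax over the logits (0, z) gives the same probability as sigmoid(z), which is why the binary case reduces to ordinary logistic regression.

```julia
# Binary case: the single logit z is squashed by the sigmoid (logistic regression).
sigmoid(z::Real) = 1 / (1 + exp(-z))

# Multi-class case: the logit vector is normalised by the softmax (multinomial logistic regression).
# Subtracting the maximum first keeps exp from overflowing and does not change the result.
function softmax(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

W = [2.0 -1.0]          # hypothetical weights: 1 output, 2 features
b = [0.5]               # hypothetical bias
x = [1.0, 1.0]
z = (W * x .+ b)[1]     # logit = 2.0 - 1.0 + 0.5 = 1.5
sigmoid(z)              # ≈ 0.8176

softmax([1.0, 2.0, 3.0])            # probabilities summing to 1; the largest logit gets the largest share
softmax([0.0, z])[2] ≈ sigmoid(z)   # true: two-class softmax matches the sigmoid
```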
CounterfactualExplanations.Models.Model (Method)
(M::Model)(data::CounterfactualData, type::MLP; kwargs...)
Constructs a multi-layer perceptron (MLP) for the given data.
CounterfactualExplanations.Models.Model (Method)
(M::Model)(data::CounterfactualData; kwargs...)
Wrap model M around the data in data.
CounterfactualExplanations.Models.Model (Method)
Model(type::AbstractModelType; likelihood::Symbol=:classification_binary)
Outer constructor for Model where the atomic model is not yet defined.
CounterfactualExplanations.Models.fit_model (Function)
fit_model(
    counterfactual_data::CounterfactualData, model::Symbol=:MLP;
    kwrgs...
)
Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.
CounterfactualExplanations.Models.fit_model (Method)
fit_model(
    counterfactual_data::CounterfactualData, type::AbstractModelType; kwrgs...
)
A wrapper function to fit a model to the counterfactual_data for a given type of model.
Arguments
- counterfactual_data::CounterfactualData: The data to be used for training the model.
- type::AbstractModelType: The type of model to be trained, e.g., MLP, DecisionTreeModel, etc.
Examples
julia> using CounterfactualExplanations
julia> using CounterfactualExplanations.Models
julia> using TaijaData
julia> data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(data, Linear())
CounterfactualExplanations.Models.Model(Chain(Dense(2 => 2)), :classification_multi, CounterfactualExplanations.Models.Fitresult(Chain(Dense(2 => 2)), Dict{Any, Any}()), Linear())
CounterfactualExplanations.Models.logits (Method)
logits(M::Model, X::AbstractArray)
Returns the logits of the model.
CounterfactualExplanations.Models.logits (Method)
logits(M::Model, type::AbstractFluxNN, X::AbstractArray)
Overloads the logits function for Flux models.
CounterfactualExplanations.Models.logits (Method)
logits(M::Model, type::MLJModelType, X::AbstractArray)
Overloads the logits method for MLJ models.
CounterfactualExplanations.Models.logits (Method)
logits(M::Model, type::DeepEnsemble, X::AbstractArray)
Overloads the logits function for deep ensembles.
CounterfactualExplanations.Models.model_evaluation (Method)
model_evaluation(M::AbstractModel, test_data::CounterfactualData)
Helper function to compute an evaluation measure for an AbstractModel on a (test) data set. By default, it computes the accuracy. Any other measure, e.g. the F-score or another measure from the StatisticalMeasures package, can be passed as an argument. Currently, only measures applicable to classification tasks are supported.
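For reference, the default measure (accuracy) and a binary F-score can be sketched for label vectors as follows. accuracy_sketch and f1_sketch are hypothetical plain-Julia helpers; they are not the StatisticalMeasures implementations, and f1_sketch treats one label as the positive class.

```julia
# Hypothetical sketch: share of predictions that match the true labels.
accuracy_sketch(y_true::AbstractVector, y_pred::AbstractVector) =
    count(y_true .== y_pred) / length(y_true)

# Hypothetical sketch: binary F1 score, treating `positive` as the positive class.
function f1_sketch(y_true::AbstractVector, y_pred::AbstractVector; positive)
    tp = count((y_true .== positive) .& (y_pred .== positive))  # true positives
    fp = count((y_true .!= positive) .& (y_pred .== positive))  # false positives
    fn = count((y_true .== positive) .& (y_pred .!= positive))  # false negatives
    return 2tp / (2tp + fp + fn)
end

y_true = [1, 2, 2, 1]
y_pred = [1, 2, 1, 1]
accuracy_sketch(y_true, y_pred)        # 0.75
f1_sketch(y_true, y_pred; positive=2)  # 2/3 ≈ 0.667
```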
CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData, X::AbstractArray)
Returns the predicted output label for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.predict_label (Method)
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData)
Returns the predicted output labels for all data points of data set counterfactual_data for a given model M.
CounterfactualExplanations.Models.predict_proba (Method)
predict_proba(M::AbstractModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})
Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.probs (Method)
probs(M::Model, X::AbstractArray)
Returns the probabilities of the model.
CounterfactualExplanations.Models.probs (Method)
probs(M::Model, type::AbstractFluxNN, X::AbstractArray)
Overloads the probs function for Flux models.
CounterfactualExplanations.Models.probs (Method)
probs(
    M::Model,
    type::MLJModelType,
    X::AbstractArray,
)
Overloads the probs method for MLJ models.
Note for developers
Note that currently the underlying MLJ methods (reformat, predict) are incompatible with Zygote's autodiff. For differentiable MLJ models, the probs and logits methods need to be overloaded.
CounterfactualExplanations.Models.probs (Method)
probs(M::Model, type::DeepEnsemble, X::AbstractArray)
Overloads the probs function for deep ensembles.
CounterfactualExplanations.Generators.generator_catalogue (Constant)
A dictionary containing the constructors of all available counterfactual generators.
CounterfactualExplanations.Generators.AbstractGradientBasedGenerator (Type)
AbstractGradientBasedGenerator
An abstract type that serves as the base type for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.AbstractNonGradientBasedGenerator (Type)
AbstractNonGradientBasedGenerator
An abstract type that serves as the base type for non gradient-based counterfactual generators.
CounterfactualExplanations.Generators.FeatureTweakGenerator (Type)
Feature Tweak counterfactual generator class.
CounterfactualExplanations.Generators.FeatureTweakGenerator (Method)
FeatureTweakGenerator(; penalty::Union{Nothing,Function,Vector{Function}}=Objectives.distance_l2, ϵ::AbstractFloat=0.1)
Constructs a new Feature Tweak Generator object.
By default, it uses the L2-norm as the penalty to measure the distance between the counterfactual and the factual. According to the paper by Tolomei et al., another recommended choice for the penalty in addition to the L2-norm is the L0-norm. The L0-norm simply minimizes the number of features that are changed through the tweak.
Arguments
- penalty::Penalty: The penalty function to use for the generator. Defaults to distance_l2.
- ϵ::AbstractFloat: The tolerance value for the feature tweaks. Described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf). Defaults to 0.1.
Returns
- generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.
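The difference between the two penalties can be made concrete with plain-Julia versions of both distances. The helpers below are hypothetical (the package provides its own Objectives.distance_l2), and treating a feature as changed whenever its absolute change exceeds atol is an assumption of this sketch. The two counterfactuals are equally far from the factual in the L2 sense, but the second one changes fewer features, which is what an L0 penalty rewards.

```julia
# L2 distance: Euclidean length of the change from factual to counterfactual.
l2_distance(x, x_cf) = sqrt(sum(abs2, x_cf .- x))

# L0 "distance": number of features whose absolute change exceeds atol.
l0_distance(x, x_cf; atol=0.0) = count(d -> abs(d) > atol, x_cf .- x)

x     = [1.0, 2.0, 3.0]
x_cf1 = [1.0, 5.0, 7.0]   # changes two features
x_cf2 = [1.0, 2.0, 8.0]   # changes one feature

l2_distance(x, x_cf1), l0_distance(x, x_cf1)   # (5.0, 2)
l2_distance(x, x_cf2), l0_distance(x, x_cf2)   # (5.0, 1)
```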
CounterfactualExplanations.Generators.GradientBasedGenerator (Type)
Base class for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.GradientBasedGenerator (Method)
GradientBasedGenerator(;
    loss::Union{Nothing,Function}=nothing,
    penalty::Penalty=nothing,
    λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
    latent_space::Bool=false,
    opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
    generative_model_params::NamedTuple=(;),
)
Default outer constructor for GradientBasedGenerator.
Arguments
- loss::Union{Nothing,Function}=nothing: The loss function used by the model.
- penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
- λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
- latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
- opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(): The optimizer to use for the generator.
- generative_model_params::NamedTuple: The parameters of the generative model associated with the generator.
Returns
- generator::GradientBasedGenerator: A gradient-based counterfactual generator.
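How a loss, a penalty, its weight λ and a descent optimiser fit together can be illustrated with a toy search. The sketch below is not the package's implementation: it assumes a fixed logistic model p(y=1 | x) = σ(w'x + b), uses the negative log-probability of the target class as the loss and the squared L2 distance to the factual as the penalty, and applies plain gradient descent with a hypothetical step size η and hand-derived gradients. toy_search and σ are hypothetical names.

```julia
# Hypothetical toy search (not the package's implementation):
#   minimise  -log σ(w'x_cf + b) + λ * ||x_cf - x||^2
# for a fixed logistic model, using plain gradient descent with step size η.
σ(z) = 1 / (1 + exp(-z))

function toy_search(x, w, b; λ=0.1, η=0.5, steps=100)
    x_cf = copy(x)                                      # start the search at the factual
    for _ in 1:steps
        p = σ(w' * x_cf + b)                            # current target-class probability
        grad_loss = -(1 - p) .* w                       # gradient of -log σ(w'x_cf + b)
        grad_penalty = 2 .* (x_cf .- x)                 # gradient of ||x_cf - x||^2
        x_cf .-= η .* (grad_loss .+ λ .* grad_penalty)  # one descent step
    end
    return x_cf
end

w, b = [1.0, -1.0], 0.0
x = [-1.0, 1.0]             # factual: σ(-2) ≈ 0.12, so the target class (1) is not predicted
x_cf = toy_search(x, w, b)
σ(w' * x_cf + b)            # target-class probability after the search (above 0.5)
```

Raising λ (for example toy_search(x, w, b; λ=1.0)) keeps the counterfactual closer to the factual but yields a lower target-class probability, which is the trade-off the penalty weight controls.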
CounterfactualExplanations.Generators.GrowingSpheresGenerator (Type)
Growing Spheres counterfactual generator class.
CounterfactualExplanations.Generators.GrowingSpheresGenerator (Method)
GrowingSpheresGenerator(; n::Int=100, η::Float64=0.1, kwargs...)
Constructs a new Growing Spheres Generator object.
CounterfactualExplanations.Generators.JSMADescent (Type)
An optimisation rule that can be used to implement a Jacobian-based Saliency Map Attack.
CounterfactualExplanations.Generators.JSMADescent (Method)
Outer constructor for the JSMADescent rule.
CounterfactualExplanations.Generators.CLUEGenerator (Method)
Constructor for CLUEGenerator. For details, see Antoran et al. (2021).
CounterfactualExplanations.Generators.ClaPROARGenerator (Method)
Constructor for ClaPROARGenerator. For details, see Altmeyer et al. (2023).
CounterfactualExplanations.Generators.DiCEGenerator (Method)
Constructor for DiCEGenerator. For details, see Mothilal et al. (2020).
CounterfactualExplanations.Generators.ECCoGenerator
β MethodConstructor for ECCoGenerator
. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty. For details, see Altmeyer et al. (2024).
CounterfactualExplanations.Generators.GenericGenerator
β MethodConstructor for GenericGenerator
.
CounterfactualExplanations.Generators.GravitationalGenerator
β MethodConstructor for GravitationalGenerator
. For details, see Altmeyer et al. (2023).
CounterfactualExplanations.Generators.GreedyGenerator
β MethodConstructor for GreedyGenerator
. For details, see Schut et al. (2021).
CounterfactualExplanations.Generators.ProbeGenerator
β MethodConstructor for ProbeGenerator
. For details, see Pawelczyk et al. (2022).
Warning
The ProbeGenerator is currently not working adequately. In particular, gradients are not computed with respect to the hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue #376 for details.
CounterfactualExplanations.Generators.REVISEGenerator
Method: Constructor for REVISEGenerator. For details, see Joshi et al. (2019).
CounterfactualExplanations.Generators.WachterGenerator
Method: Constructor for WachterGenerator. For details, see Wachter et al. (2018).
CounterfactualExplanations.Generators.generate_perturbations
Method: generate_perturbations(
    generator::AbstractGenerator, ce::AbstractCounterfactualExplanation
)
The default method to generate feature perturbations for any generator.
CounterfactualExplanations.Generators.generate_perturbations
Method: generate_perturbations(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to generate feature perturbations for gradient-based generators through simple gradient descent.
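For intuition on gradient-based generation, here is a self-contained Julia sketch of plain gradient descent on a Wachter-style objective for a hypothetical logistic-regression model. The objective is the binary cross-entropy towards the target label plus λ times the squared L2 distance to the factual. The weights `w` and `b`, the decision threshold `γ` and all function names are assumptions made for this example. The package's own generate_perturbations works through its generator and model types rather than through this closed-form gradient.

```julia
using LinearAlgebra

σ(z) = 1 / (1 + exp(-z))

# Hypothetical logistic-regression model: p(y = 1 | x) = σ(w ⋅ x + b).
const w = [2.0, -1.0]
const b = -0.5

# Analytic gradient of BCE(σ(w ⋅ x + b), y) + λ * ||x - x0||² with respect to x.
grad_objective(x, x0, y, λ) = (σ(dot(w, x) + b) - y) .* w .+ 2λ .* (x .- x0)

# Hypothetical search: take gradient steps from the factual `x0` until the target
# class `y` (0 or 1) has probability at least `γ`, or `max_iter` steps are used up.
function gradient_counterfactual(x0::Vector{Float64}, y::Int; λ=0.01, lr=0.5, γ=0.75, max_iter=1_000)
    x = copy(x0)
    for _ in 1:max_iter
        p = σ(dot(w, x) + b)
        (y == 1 ? p : 1 - p) >= γ && break
        x .-= lr .* grad_objective(x, x0, y, λ)
    end
    return x
end

x_cf = gradient_counterfactual([0.0, 0.0], 1)
```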
CounterfactualExplanations.Generators.@objective
Macro: @objective(generator, ex)
A macro that can be used to define the counterfactual search objective.
CounterfactualExplanations.Generators.@search_feature_space
Macro: @search_feature_space(generator)
A simple macro that can be used to specify feature space search.
CounterfactualExplanations.Generators.@search_latent_space
Macro: @search_latent_space(generator)
A simple macro that can be used to specify latent space search.
CounterfactualExplanations.Generators.@with_optimiser
Macro: @with_optimiser(generator, optimiser)
A simple macro that can be used to specify the optimiser to be used.
CounterfactualExplanations.Objectives.ddp_diversity
Method: ddp_diversity(
    ce::AbstractCounterfactualExplanation;
    perturbation_size=1e-5
)
Evaluates how diverse the counterfactuals are using a Determinantal Point Process (DPP).
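The DPP-based diversity score can be sketched in a few lines of Julia. The sketch follows the kernel used in DiCE (Mothilal et al., 2020), K[i, j] = 1 / (1 + d(xᵢ, xⱼ)) over pairs of counterfactuals, and takes det(K) as the diversity. In the hypothetical `dpp_diversity` helper, counterfactuals are the columns of a matrix, d is the L1 distance, and `perturbation_size` is added to the diagonal for numerical stability. These details are assumptions and may differ from ddp_diversity.

```julia
using LinearAlgebra

# Hypothetical DPP-style diversity of counterfactuals stored as the columns of `X`.
function dpp_diversity(X::AbstractMatrix; perturbation_size::Real=1e-5)
    n = size(X, 2)
    # Similarity kernel: identical counterfactuals get similarity 1.
    K = [1 / (1 + norm(X[:, i] .- X[:, j], 1)) for i in 1:n, j in 1:n]
    # Small diagonal term keeps K well-conditioned.
    K += perturbation_size * I
    return det(K)
end

identical = dpp_diversity([0.0 0.0; 1.0 1.0])    # two equal counterfactuals
spread = dpp_diversity([0.0 10.0; 0.0 10.0])     # two distant counterfactuals
```

Near-duplicate counterfactuals make the rows of K almost identical, which drives the determinant towards zero. Spread-out counterfactuals push K towards the identity and the determinant towards one.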
CounterfactualExplanations.Objectives.distance
Method: distance(
    ce::AbstractCounterfactualExplanation;
    from::Union{Nothing,AbstractArray}=nothing,
    agg=mean,
    p::Real=1,
    weights::Union{Nothing,AbstractArray}=nothing,
)
Computes the distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l0
Method: distance_l0(ce::AbstractCounterfactualExplanation)
Computes the L0 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l1
Method: distance_l1(ce::AbstractCounterfactualExplanation)
Computes the L1 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l2
Method: distance_l2(ce::AbstractCounterfactualExplanation)
Computes the L2 (Euclidean) distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_linf
Method: distance_linf(ce::AbstractCounterfactualExplanation)
Computes the L-infinity distance (largest absolute feature change) of the counterfactual to the original factual.
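The four norm-based distances above can be illustrated on plain vectors with Julia's LinearAlgebra standard library. The helper names below are hypothetical. They operate on a single factual/counterfactual pair rather than on a CounterfactualExplanation.

```julia
using LinearAlgebra

# Hypothetical helpers: distance between factual `x` and counterfactual `xcf`.
dist_l0(x, xcf)   = count(!iszero, xcf .- x)    # number of features that changed
dist_l1(x, xcf)   = norm(xcf .- x, 1)           # sum of absolute changes
dist_l2(x, xcf)   = norm(xcf .- x, 2)           # Euclidean distance
dist_linf(x, xcf) = norm(xcf .- x, Inf)         # largest single change

x   = [1.0, 2.0, 3.0]
xcf = [1.0, 4.0, 0.0]
```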
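CounterfactualExplanations.Objectives.distance_mad
Method: distance_mad(ce::AbstractCounterfactualExplanation; agg=mean)
Computes the distance measure proposed by Wachter et al. (2017): an L1 distance in which each feature's change is scaled by that feature's median absolute deviation (MAD).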
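Here is a minimal Julia sketch of the MAD-weighted L1 distance from Wachter et al. (2017). Each absolute feature change is divided by that feature's median absolute deviation across the training data. The sketch assumes features are stored in rows and observations in columns. It also assumes that no feature has a MAD of zero, since that would cause a division by zero. The function name is hypothetical.

```julia
using Statistics

# Hypothetical MAD-weighted L1 distance between a factual `x` and a
# counterfactual `xcf`, using the training data `X` (features × observations).
function mad_distance(x::AbstractVector, xcf::AbstractVector, X::AbstractMatrix)
    # Median absolute deviation of each feature (row) of X.
    mad = [median(abs.(row .- median(row))) for row in eachrow(X)]
    return sum(abs.(x .- xcf) ./ mad)
end

X = [1.0 2.0 3.0 4.0 5.0;
     10.0 20.0 30.0 40.0 50.0]
d = mad_distance([0.0, 0.0], [1.0, 10.0], X)
```

Scaling by the MAD puts features measured on very different scales on an equal footing. In the example, a change of 1 in the first feature counts the same as a change of 10 in the second.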
CounterfactualExplanations.Objectives.hinge_loss
Method: hinge_loss(ce::AbstractCounterfactualExplanation)
Calculates the hinge loss of a counterfactual explanation with InvalidationRateConvergence.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.
Returns
- The hinge loss of the counterfactual explanation.
CounterfactualExplanations.Objectives.predictive_entropy
Method: predictive_entropy(ce::AbstractCounterfactualExplanation; agg=Statistics.mean)
Computes the predictive entropy of the counterfactuals, as explained in https://arxiv.org/abs/1406.2541.
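Predictive entropy measures how uncertain the model is about a counterfactual: H(p) = -Σₖ pₖ log pₖ for the predicted class probabilities p. The Julia sketch below computes it for one probability vector. It then averages it over several counterfactuals stored as columns, mirroring the default agg=mean. The function names are hypothetical.

```julia
using Statistics

# Hypothetical helper: entropy of one probability vector, treating 0 * log(0) as 0.
entropy_of(p::AbstractVector) = -sum(pk > 0 ? pk * log(pk) : 0.0 for pk in p)

# Hypothetical helper: average entropy over counterfactuals stored as columns
# of class probabilities.
mean_predictive_entropy(P::AbstractMatrix) = mean(entropy_of, eachcol(P))

P = [0.5 1.0;
     0.5 0.0]   # first column: maximally uncertain; second: certain
```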
Flux.Losses.logitbinarycrossentropy
Method: Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Flux.Losses.logitcrossentropy
Method: Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Flux.Losses.mse
Method: Flux.Losses.mse(ce::AbstractCounterfactualExplanation)
Simply extends the mse method to work with objects of type AbstractCounterfactualExplanation.
Internal functions
CounterfactualExplanations.CRE
Type: CRE <: AbstractCounterfactualExplanation
A Counterfactual Rule Explanation (CRE) is a global explanation for a given target, model M, data and generator.
CounterfactualExplanations.CRE
Method: (cre::CRE)(x::AbstractArray)
Generates a local counterfactual point explanation for x using the generator.
CounterfactualExplanations.DecisionTreeModel
Type: DecisionTreeModel
Concrete type for tree-based models from DecisionTree.jl. Since DecisionTree.jl has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.FluxModelParams
Type: FluxModelParams
Default MLP training parameters.
CounterfactualExplanations.JEM
Type: JEM
Concrete type for joint-energy models from JointEnergyModels. Since JointEnergyModels has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.LaplaceReduxModel
Type: LaplaceReduxModel
Concrete type for neural networks with Laplace Approximation from the LaplaceRedux package. It currently subtypes the AbstractFluxNN model type, although this may change to MLJ in the future.
CounterfactualExplanations.NeuroTreeModel
Type: NeuroTreeModel
Concrete type for differentiable tree-based models from NeuroTreeModels. Since NeuroTreeModels has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.RandomForestModel
Type: RandomForestModel
Concrete type for random forest models from DecisionTree.jl. Since the DecisionTree package has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.Rule
Type: Rule
A Rule is just a list of bounds for the different features. See also CRE.
Base.Broadcast.broadcastable
Method: Treat AbstractGenerator as scalar when broadcasting.
Base.Broadcast.broadcastable
Method: Treat AbstractModel as scalar when broadcasting.
Base.Broadcast.broadcastable
Method: Treat AbstractPenalty as scalar when broadcasting.
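The broadcasting methods above rely on a standard Julia idiom. Returning `Ref(x)` from `Base.Broadcast.broadcastable` makes broadcasting treat `x` as a single scalar instead of trying to iterate over it. The sketch below applies the idiom to a hypothetical `ToyGenerator` type and a hypothetical `perturb` function.

```julia
# Hypothetical stand-in for a generator type.
struct ToyGenerator
    step::Float64
end

# Treat ToyGenerator as a scalar when broadcasting.
Base.Broadcast.broadcastable(g::ToyGenerator) = Ref(g)

# Hypothetical function that applies the generator to one feature value.
perturb(g::ToyGenerator, x::Real) = x + g.step

g = ToyGenerator(0.5)
shifted = perturb.(g, [1.0, 2.0, 3.0])   # the same generator is reused for every element
```

Without the `broadcastable` method, the broadcast would try to iterate over `g` and fail, because `ToyGenerator` defines no iteration.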
CounterfactualExplanations.adjust_shape!
Method: adjust_shape!(ce::CounterfactualExplanation)
A convenience method that adjusts the dimensions of the counterfactual state and related fields.
CounterfactualExplanations.adjust_shape
Method: adjust_shape(
    ce::CounterfactualExplanation,
    x::AbstractArray
)
A convenience method that adjusts the dimensions of x.
CounterfactualExplanations.already_in_target_class
Method: already_in_target_class(ce::CounterfactualExplanation)
Check if the factual is already in the target class.
CounterfactualExplanations.apply_domain_constraints!
Method: apply_domain_constraints!(ce::CounterfactualExplanation)
Wrapper function that applies underlying domain constraints.
CounterfactualExplanations.apply_mutability
Method: apply_mutability(
    ce::CounterfactualExplanation,
    Δcounterfactual_state::AbstractArray,
)
A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
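Here is a minimal sketch, written as a standalone Julia function, of what applying mutability constraints to a proposed perturbation vector can look like. It assumes one constraint symbol per feature. It uses :both (as in mutability_constraints) plus :increase, :decrease and :none. The function name `constrain_perturbation` is hypothetical.

```julia
# Hypothetical helper: zero out the components of the proposed perturbation `Δ`
# that would move a feature in a direction its mutability constraint forbids.
function constrain_perturbation(Δ::AbstractVector{<:Real}, mutability::AbstractVector{Symbol})
    Δ = copy(Δ)   # leave the caller's vector untouched
    for (i, m) in enumerate(mutability)
        forbidden = m == :none ||
                    (m == :increase && Δ[i] < 0) ||
                    (m == :decrease && Δ[i] > 0)
        if forbidden
            Δ[i] = zero(Δ[i])
        end
    end
    return Δ
end

Δ = [1.0, -1.0, 1.0, -1.0, 2.0]
mutability = [:both, :increase, :decrease, :none, :increase]
Δ_allowed = constrain_perturbation(Δ, mutability)
```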
CounterfactualExplanations.counterfactual
Method: counterfactual(ce::CounterfactualExplanation)
A convenience method that returns the counterfactual.
CounterfactualExplanations.counterfactual_label
Method: counterfactual_label(ce::CounterfactualExplanation)
A convenience method that returns the predicted label of the counterfactual.
CounterfactualExplanations.counterfactual_label_path
Method: counterfactual_label_path(ce::CounterfactualExplanation)
Returns the counterfactual labels for each step of the search.
CounterfactualExplanations.counterfactual_probability
Function: counterfactual_probability(ce::CounterfactualExplanation)
A convenience method that computes the class probabilities of the counterfactual.
CounterfactualExplanations.counterfactual_probability_path
Method: counterfactual_probability_path(ce::CounterfactualExplanation)
Returns the counterfactual probabilities for each step of the search.
CounterfactualExplanations.decode_array
Method: decode_array(
    data::CounterfactualData,
    dt::CausalInference.SCM,
    x::AbstractArray,
)
Helper function to decode an array x using a data transform dt::CausalInference.SCM.
CounterfactualExplanations.decode_array
Method: decode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to decode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.decode_array
Method: decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.decode_array
Method: decode_array(dt::Nothing, x::AbstractArray)
Helper function to decode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.decode_array
Method: decode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to decode an array x using a data transform dt::StatsBase.AbstractDataTransform.
CounterfactualExplanations.decode_state
Function: decode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Applies all the applicable decoding functions:
- If applicable, map the state variable back from the latent space to the feature space.
- If and where applicable, inverse-transform features.
- Reconstruct all categorical encodings.
Finally, the decoded counterfactual is returned.
CounterfactualExplanations.decode_state!
Function: decode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of decode_state.
CounterfactualExplanations.encode_array
Method: encode_array(data::CounterfactualData, dt::CausalInference.SCM, x::AbstractArray)
Helper function to encode an array x using a data transform dt::CausalInference.SCM. This is a no-op.
CounterfactualExplanations.encode_array
Method: encode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to encode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.encode_array
Method: encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.encode_array
Method: encode_array(dt::Nothing, x::AbstractArray)
Helper function to encode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.encode_array
Method: encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to encode an array x using a data transform dt::StatsBase.AbstractDataTransform.
CounterfactualExplanations.encode_state
Function: encode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)
Applies all required encodings to x:
- If applicable, it maps x to the latent space learned by the generative model.
- If and where applicable, it rescales features.
Finally, it returns the encoded state variable.
CounterfactualExplanations.encode_state!
Function: encode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of encode_state.
CounterfactualExplanations.factual
Method: factual(ce::CounterfactualExplanation)
A convenience method to retrieve the factual x.
CounterfactualExplanations.factual_label
Method: factual_label(ce::CounterfactualExplanation)
A convenience method to get the predicted label associated with the factual.
CounterfactualExplanations.factual_probability
Method: factual_probability(ce::CounterfactualExplanation)
A convenience method to compute the class probabilities of the factual.
CounterfactualExplanations.find_potential_neighbours
Function: find_potential_neighbours(ce::AbstractCounterfactualExplanation)
Finds potential neighbors for the selected factual data point.
CounterfactualExplanations.get_meta
Method: get_meta(ce::CounterfactualExplanation)
Returns metadata for a counterfactual explanation.
CounterfactualExplanations.guess_likelihood
Method: guess_likelihood(y::RawOutputArrayType)
Guesses the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.
CounterfactualExplanations.guess_loss
Method: guess_loss(ce::CounterfactualExplanation)
Guesses the loss function to be used for the counterfactual search in case the likelihood field is specified for the AbstractModel instance and no loss function was explicitly declared for the AbstractGenerator instance.
CounterfactualExplanations.initialize!
Method: initialize!(ce::CounterfactualExplanation)
Initializes the counterfactual explanation. This method is called by the constructor. It does the following:
- Creates a dictionary to store information about the search.
- Initializes the counterfactual state.
- Initializes the search path.
- Initializes the loss.
CounterfactualExplanations.initialize_state!
Method: initialize_state!(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s) in-place.
CounterfactualExplanations.initialize_state
Method: initialize_state(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s):
- If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
- If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack et al. (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.
CounterfactualExplanations.outdim
Method: outdim(ce::CounterfactualExplanation)
A convenience method that returns the output dimension of the predictive model.
CounterfactualExplanations.polynomial_decay
Method: polynomial_decay(a::Real, b::Real, decay::Real, t::Int)
Computes the polynomial decay function as in Welling and Teh (2011): https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf.
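Welling and Teh (2011) use step sizes of the form εₜ = a (b + t)^(-γ). These satisfy the Robbins-Monro conditions (Σ εₜ = ∞, Σ εₜ² < ∞) for γ in (0.5, 1]. The Julia sketch below implements that formula with the argument order of polynomial_decay. It is an illustration under that assumption, not the package's source, and the name `decay_step` is hypothetical.

```julia
# Hypothetical helper: polynomially decaying step size ε_t = a * (b + t)^(-decay).
# `float` avoids a DomainError when integer arguments are raised to a negative power.
decay_step(a::Real, b::Real, decay::Real, t::Int) = a * float(b + t)^(-decay)

steps = [decay_step(1.0, 1.0, 0.55, t) for t in 0:5]
```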
CounterfactualExplanations.reset!
Method: reset!(flux_training_params::FluxModelParams)
Restores the default parameter values.
CounterfactualExplanations.steps_exhausted
Method: steps_exhausted(ce::CounterfactualExplanation)
A convenience method that checks whether the maximum number of iterations has been exhausted.
CounterfactualExplanations.target_probs_path
Method: target_probs_path(ce::CounterfactualExplanation)
Returns the target probabilities for each step of the search.
CounterfactualExplanations.update!
Method: update!(ce::CounterfactualExplanation)
An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator. Based on the current state, the generator generates perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.
CounterfactualExplanations.Convergence.max_iter
Method: max_iter(conv::AbstractConvergence)
Returns the maximum number of iterations specified.
CounterfactualExplanations.Evaluation.distance_measures
Constant: All distance measures.
CounterfactualExplanations.Evaluation.EnergySampler
Type: Base type that stores information relevant to energy-based posterior sampling from AbstractModel.
CounterfactualExplanations.Evaluation.EnergySampler
Method: EnergySampler(
    model::AbstractModel,
    𝒟x::Distribution,
    𝒟y::Distribution,
    input_size::Dims,
    yidx::Int;
    opt::Union{Nothing,AbstractSamplingRule}=nothing,
    nsamples::Int=100,
    niter_final::Int=1000,
    ntransitions::Int=0,
    opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing,
    niter::Int=20,
    batch_size::Int=50,
    prob_buffer::AbstractFloat=0.95,
    kwargs...,
)
Constructor for EnergySampler, which is used to sample from the posterior distribution of the model conditioned on y.
Arguments
- model::AbstractModel: The model to be used for sampling.
- data::CounterfactualData: The data to be used for sampling.
- y::Any: The conditioning value.
- opt::AbstractSamplingRule=ImproperSGLD(): The sampling rule to be used. By default, SGLD is used with a = (2 / std(Uniform())) * std(𝒟x), b = 1 and γ = 0.9.
- nsamples::Int=100: The number of samples to include in the final empirical posterior distribution.
- niter_final::Int=1000: The number of iterations for generating samples from the posterior distribution. Typically, this number will be larger than the number of iterations during PMC training.
- ntransitions::Int=0: The number of transitions for (optionally) warming up the sampler. By default, this is set to 0 and the sampler is not warmed up. For values larger than 0, the sampler is trained through PMC for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.
- opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing: The sampling rule to be used for warm-up. By default, ImproperSGLD is used with α = (2 / std(Uniform())) * std(𝒟x) and γ = 0.005α.
- niter::Int=100: The number of iterations for training the sampler through PMC.
- batch_size::Int=50: The batch size for training the sampler.
- prob_buffer::AbstractFloat=0.5: The probability of drawing samples from the replay buffer. Smaller values will result in more samples being drawn from the prior and typically lead to better mixing and diversity in the samples.
- kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.
Returns
- EnergySampler: An instance of EnergySampler.
CounterfactualExplanations.Evaluation.EnergySampler
Method: EnergySampler(ce::CounterfactualExplanation; kwargs...)
Overloads the EnergySampler constructor to accept a CounterfactualExplanation object.
Base.rand
Function: Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
Overloads the rand method to randomly draw n samples from EnergySampler. If from_posterior is true, the samples are drawn from the posterior distribution. Otherwise, the samples are generated from the model conditioned on the target value using a single chain (see generate_posterior_samples).
Arguments
- sampler::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=100: The number of samples to draw.
- from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
- niter::Int=500: The number of iterations for generating samples through Monte Carlo sampling (single chain).
Returns
- AbstractArray: The samples.
Base.vcat
Method: Base.vcat(bmk1::Benchmark, bmk2::Benchmark)
Vertically concatenates two Benchmark objects.
CounterfactualExplanations.Evaluation.compute_measure
Method: compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)
Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.
CounterfactualExplanations.Evaluation.define_prior
Method: define_prior(
    data::CounterfactualData;
    𝒟x::Union{Nothing,Distribution}=nothing,
    𝒟y::Union{Nothing,Distribution}=nothing,
    n_std::Int=3,
)
Defines the prior for the data. The space is defined as a uniform distribution with bounds defined by the mean and standard deviation of the data. The bounds are extended by n_std standard deviations.
Arguments
- data::CounterfactualData: The data to be used for defining the prior sampling space.
- n_std::Int=3: The number of standard deviations to extend the bounds.
Returns
- Uniform: The uniform distribution defining the prior sampling space.
CounterfactualExplanations.Evaluation.distance_from_posterior
Method: distance_from_posterior(ce::AbstractCounterfactualExplanation)
Computes the distance from the counterfactual to generated conditional samples. The distance is computed as the mean distance from the counterfactual to the samples drawn from the posterior distribution of the model. By default, the cosine distance is used.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
- nsamples::Int=1000: The number of samples to draw.
- from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
- agg: The aggregation function to use for computing the distance.
- choose_lowest_energy::Bool=true: Whether to choose the samples with the lowest energy.
- choose_random::Bool=false: Whether to choose random samples.
- nmin::Int=25: The minimum number of samples to choose.
- p::Int=1: The norm to use for computing the distance.
- cosine::Bool=true: Whether to use the cosine distance.
- kwargs...: Additional keyword arguments to be passed on to the EnergySampler.
Returns
- AbstractFloat: The distance from the counterfactual to the samples.
CounterfactualExplanations.Evaluation.generate_posterior_samples
Function: generate_posterior_samples(
    e::EnergySampler, n::Int=1000; niter::Int=1000, kwargs...
)
Generates n samples from the posterior distribution of the model conditioned on the target value y. The samples are generated through (Persistent) Monte Carlo sampling using the EnergySampler object. If the replay buffer is not empty, the initial samples are drawn from the buffer.
Note that, by default, the batch size of the sampler is set to round(Int, n / 100) for sampling. This is to ensure that the samples are drawn independently from the posterior distribution. It also helps to avoid vanishing gradients.
The chain is run persistently until n samples are generated. The number of transitions is set to ceil(Int, n / batch_size). Once the chain is run, the last n samples from the replay buffer are returned.
Arguments
- e::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=100: The number of samples to generate.
- batch_size::Int=round(Int, n / 100): The batch size for sampling.
- niter::Int=1000: The number of iterations for generating samples from the posterior distribution.
- kwargs...: Additional keyword arguments to be passed on to the sampler.
Returns
- AbstractArray: The generated samples.
CounterfactualExplanations.Evaluation.get_lowest_energy_sample
Method: get_lowest_energy_sample(sampler::EnergySampler; n::Int=5)
Chooses the samples with the lowest energy (i.e. highest probability) from EnergySampler.
Arguments
- sampler::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=5: The number of samples to choose.
Returns
- AbstractArray: The samples with the lowest energy.
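Selecting the lowest-energy samples amounts to a partial sort on the energies. The Julia sketch below assumes that samples are stored as columns and that per-sample energies have already been computed. The function name is hypothetical.

```julia
# Hypothetical helper: return the `n` columns of `samples` with the lowest energy.
# Lower energy corresponds to higher (unnormalised) probability under the model.
function lowest_energy_columns(samples::AbstractMatrix, energies::AbstractVector; n::Int=5)
    k = min(n, length(energies))
    idx = partialsortperm(energies, 1:k)   # indices of the k smallest energies
    return samples[:, idx]
end

samples = [1.0 2.0 3.0 4.0]
energies = [0.3, 0.1, 0.9, 0.2]
best = lowest_energy_columns(samples, energies; n=2)
```

A partial sort avoids sorting the full energy vector when only a handful of samples are needed.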
CounterfactualExplanations.Evaluation.get_sampler!
Method: get_sampler!(ce::AbstractCounterfactualExplanation; kwargs...)
Gets the EnergySampler object from the counterfactual explanation. If the sampler is not found, it is constructed and stored in the counterfactual explanation object.
CounterfactualExplanations.Evaluation.to_dataframe
Method: to_dataframe(
    ce::CounterfactualExplanation,
    measure::Vector{Function},
    agg::Function,
    report_each::Bool,
    pivot_longer::Bool,
    store_ce::Bool,
)
Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.
CounterfactualExplanations.Evaluation.validity_strict
Method: validity_strict(ce::CounterfactualExplanation)
Checks if the counterfactual search has been strictly valid in the sense that it has converged with respect to the pre-specified target probability γ.
CounterfactualExplanations.Evaluation.warmup!
Method: warmup!(
    e::EnergySampler,
    y::Int;
    niter::Int=20,
    ntransitions::Int=100,
    kwargs...,
)
Warms up the EnergySampler to the underlying model for conditioning value y. Specifically, this entails running PMC for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.
Arguments
- e::EnergySampler: The EnergySampler object to be trained.
- y::Int: The conditioning value.
- opt::Union{Nothing,AbstractSamplingRule}: The sampling rule to be used. By default, ImproperSGLD is used with α = 2 * std(Uniform(𝒟x)) and γ = 0.005α.
- niter::Int=20: The number of iterations for training the sampler through PMC.
- ntransitions::Int=100: The number of transitions for training the sampler. In each transition, the sampler is updated with a mini-batch of data. Data is either drawn from the replay buffer or reinitialized from the prior.
- kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.
Returns
- EnergySampler: The trained EnergySampler.
CounterfactualExplanations.DataPreprocessing.InputTransformer
Type: InputTransformer
Abstract type for data transformers. This can be any of the following:
- StatsBase.AbstractDataTransform: A data transformation object from the StatsBase package.
- MultivariateStats.AbstractDimensionalityReduction: A dimensionality reduction object from the MultivariateStats package.
- GenerativeModels.AbstractGenerativeModel: A generative model object from the GenerativeModels module.
CounterfactualExplanations.DataPreprocessing.TypedInputTransformer
Type: TypedInputTransformer
Abstract type for data transformers.
Base.Broadcast.broadcastable
Method: Treat CounterfactualData as scalar when broadcasting.
CounterfactualExplanations.DataPreprocessing._subset
Method: _subset(data::CounterfactualData, idx::Vector{Int})
Creates a subset of the data.
CounterfactualExplanations.DataPreprocessing.convert_to_1d
Method: convert_to_1d(y::Matrix, y_levels::AbstractArray)
Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.
Arguments
- y::Matrix: The one-hot encoded matrix.
- y_levels::AbstractArray: The levels of the categorical variable.
Returns
- labels: A vector of labels.
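Converting a one-hot encoding back to labels boils down to picking, for each observation, the level at the position of the maximum entry. The Julia sketch assumes classes in rows and observations in columns. The name `onehot_to_labels` is hypothetical.

```julia
# Hypothetical helper: map each one-hot column of `y` to its label in `y_levels`.
onehot_to_labels(y::AbstractMatrix, y_levels::AbstractVector) =
    [y_levels[argmax(col)] for col in eachcol(y)]

y = [1 0 0;
     0 1 1]
labels = onehot_to_labels(y, ["cat", "dog"])
```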
CounterfactualExplanations.DataPreprocessing.input_dim
Method: input_dim(counterfactual_data::CounterfactualData)
Helper function that returns the input dimension (number of features) of the data.
CounterfactualExplanations.DataPreprocessing.mutability_constraints
Method: mutability_constraints(counterfactual_data::CounterfactualData)
A convenience function that returns the mutability constraints. If none were specified, it is assumed that all features are mutable in :both directions.
CounterfactualExplanations.DataPreprocessing.outdim
Method: outdim(data::CounterfactualData)
Returns the number of output classes.
CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj
Method: preprocess_data_for_mlj(data::CounterfactualData)
Helper function to preprocess data::CounterfactualData for MLJ models.
Arguments
- data::CounterfactualData: The data to be preprocessed.
Returns
- (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.
Example
X, y = preprocess_data_for_mlj(data)
CounterfactualExplanations.DataPreprocessing.reconstruct_cat_encoding
Method: reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::Vector)
Reconstructs the categorical encoding for a single instance.
CounterfactualExplanations.DataPreprocessing.subsample
Method: subsample(data::CounterfactualData, n::Int)
Helper function to randomly subsample data::CounterfactualData.
CounterfactualExplanations.DataPreprocessing.train_test_split
Method: train_test_split(data::CounterfactualData; test_size=0.2, keep_class_ratio=false)
Splits the data into a training set and a test set.
Arguments
- data::CounterfactualData: The data to be split.
- test_size=0.2: Proportion of the data to be used for testing.
- keep_class_ratio=false: Decides whether to sample equally from each class, or keep their relative size.
Returns
- (train_data::CounterfactualData, test_data::CounterfactualData): A tuple containing the train and test splits.
Example
train, test = train_test_split(data, test_size=0.1, keep_class_ratio=true)
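The Julia sketch below shows one way to realise the class-ratio-preserving behaviour of keep_class_ratio. It splits observation indices either uniformly at random or class by class, so that each class contributes the same share test_size of its observations to the test set. It works on a plain label vector rather than on CounterfactualData. The function name `split_indices` and its `rng` keyword are hypothetical.

```julia
using Random

# Hypothetical helper: split the indices of `y` into train and test sets.
# With `keep_class_ratio=true`, each class is split separately (stratified),
# so class proportions in the test set match those in `y`.
function split_indices(y::AbstractVector; test_size::Real=0.2, keep_class_ratio::Bool=false,
                       rng::AbstractRNG=MersenneTwister(42))
    test = Int[]
    groups = keep_class_ratio ? [findall(==(c), y) for c in unique(y)] : [collect(eachindex(y))]
    for idx in groups
        shuffled = shuffle(rng, idx)
        append!(test, shuffled[1:round(Int, test_size * length(idx))])
    end
    train = setdiff(eachindex(y), test)
    return train, sort(test)
end

y = vcat(fill(1, 10), fill(2, 30))
train, test = split_indices(y; test_size=0.2, keep_class_ratio=true)
```

With a 1:3 class ratio and test_size=0.2, the stratified test set receives 2 observations of the first class and 6 of the second, preserving the 1:3 ratio.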
CounterfactualExplanations.DataPreprocessing.unpack_data
Method: unpack_data(data::CounterfactualData)
Helper function that unpacks data.
CounterfactualExplanations.Models.AbstractCustomDifferentiableModel
Type: Base type for custom differentiable models.
CounterfactualExplanations.Models.AbstractDifferentiableModel
Type: Base type for differentiable models.
CounterfactualExplanations.Models.AbstractDifferentiableModelType
Type: Abstract types for differentiable models.
CounterfactualExplanations.Models.AbstractFluxModel
Type: Base type for differentiable models written in Flux.
CounterfactualExplanations.Models.AbstractFluxNN
Type: Abstract type for Flux models.
CounterfactualExplanations.Models.AbstractMLJModel
Type: Base type for differentiable models from the MLJ library.
CounterfactualExplanations.Models.AbstractModelType
Method: (type::AbstractModelType)(model; likelihood::Symbol=:classification_binary)
Wrap model type around the pre-trained model model.
CounterfactualExplanations.Models.AbstractModelType
Method: (type::AbstractModelType)(data::CounterfactualData; kwargs...)
Wrap model type around the data in data. This is a convenience function to avoid having to construct a Model object.
CounterfactualExplanations.Models.Differentiability
Type: A base type for model differentiability.
CounterfactualExplanations.Models.Differentiability
Method: Dispatches on the type of model for the differentiability trait.
CounterfactualExplanations.Models.Fitresult
Type: Fitresult
A struct to hold the results of fitting a model.
Fields
- fitresult: The result of fitting the model to the data. This object should be callable on new data.
- other::Dict: A dictionary to hold any other relevant information.
CounterfactualExplanations.Models.Fitresult
Method: (fitresult::Fitresult)(newdata::AbstractArray)
When called on new data, the Fitresult object returns the result of calling the fitresult on new data.
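Fitresult relies on Julia's callable-struct (functor) pattern, in which defining a method on the type itself lets an instance be called like a function. The sketch below reproduces the pattern with a hypothetical `ToyFitresult` whose stored model is a plain function.

```julia
# Hypothetical struct mirroring the Fitresult layout: a callable fitted model
# plus a dictionary for any other information.
struct ToyFitresult
    fitresult::Function
    other::Dict{Symbol,Any}
end

# Calling an instance on new data forwards the call to the stored model.
(f::ToyFitresult)(newdata::AbstractArray) = f.fitresult(newdata)

fr = ToyFitresult(x -> 2 .* x, Dict{Symbol,Any}(:epochs => 10))
predictions = fr([1, 2, 3])
```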
CounterfactualExplanations.Models.Fitresult
Method: (fitresult::Fitresult)()
CounterfactualExplanations.Models.FluxNN
Type: Concrete type for Flux models.
CounterfactualExplanations.Models.IsDifferentiable
Type: Struct for models that are differentiable.
CounterfactualExplanations.Models.MLJModelType
Type: Abstract type for MLJ models.
CounterfactualExplanations.Models.NonDifferentiable
Type: By default, models are assumed not to be differentiable.
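The Differentiability, IsDifferentiable and NonDifferentiable types follow a trait pattern common in Julia. A function maps each model type to a trait instance, and downstream code dispatches on that instance. The sketch below reproduces the pattern with hypothetical model types and a hypothetical `search_strategy` function. The trait names are redefined locally so that the block runs on its own.

```julia
# Trait types, redefined locally for a self-contained example.
abstract type Differentiability end
struct IsDifferentiable <: Differentiability end
struct NonDifferentiable <: Differentiability end

# Hypothetical model types.
struct ToyNeuralNet end
struct ToyDecisionTree end

# By default, models are assumed not to be differentiable.
Differentiability(::Any) = NonDifferentiable()
Differentiability(::ToyNeuralNet) = IsDifferentiable()

# Hypothetical consumer: dispatch on the trait value, not on the model type.
search_strategy(model) = search_strategy(Differentiability(model), model)
search_strategy(::IsDifferentiable, model) = :gradient_based
search_strategy(::NonDifferentiable, model) = :gradient_free
```

With this design, a new model type only needs to declare its trait. Every function that dispatches on the trait then works for it without further methods.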
CounterfactualExplanations.Models.binary_to_onehot
Method: binary_to_onehot(p)
Helper function to turn a dummy-encoded variable into a one-hot encoded variable.
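Turning a dummy-encoded (0/1) variable into a one-hot encoding amounts to stacking 1 - p on top of p, which gives one row per class. The Julia sketch assumes classes in rows and observations in columns. The name `to_onehot` is hypothetical. The same expression also maps positive-class probabilities to two-class probability columns.

```julia
# Hypothetical helper: 0/1 vector (or positive-class probabilities) -> 2×n matrix.
# Row 1 corresponds to class 0 and row 2 to class 1.
to_onehot(p::AbstractVector{<:Real}) = permutedims(hcat(1 .- p, p))

onehot = to_onehot([0, 1, 1])
```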
CounterfactualExplanations.Models.build_ensemble
β Methodbuild_ensemble(K::Int;kw=(input_dim=2,n_hidden=32,output_dim=1))
Helper function that builds an ensemble of K
models.
CounterfactualExplanations.Models.build_mlp
β Methodbuild_mlp()
Helper function to build simple MLP.
Examples
nn = build_mlp()
CounterfactualExplanations.Models.data_loader
β Methoddata_loader(data::CounterfactualData)
Prepares counterfactual data for training in Flux.
CounterfactualExplanations.Models.forward!
β Methodforward!(model::Flux.Chain, data; loss::Symbol, opt::Symbol, n_epochs::Int=10, model_name="MLP")
Forward pass for training a Flux.Chain model.
CounterfactualExplanations.Models.load_mnist_model
β Methodload_mnist_model(type::AbstractModelType)
Empty function to be overloaded for loading a pre-trained model for the AbstractModelType model type.
CounterfactualExplanations.Models.load_mnist_model
β Methodload_mnist_model(type::DeepEnsemble)
Load a pre-trained deep ensemble model for the MNIST dataset.
CounterfactualExplanations.Models.load_mnist_model
β Methodload_mnist_model(type::MLP)
Load a pre-trained MLP model for the MNIST dataset.
CounterfactualExplanations.Models.load_mnist_vae
β Methodload_mnist_vae(; strong=true)
Load a pre-trained VAE model for the MNIST dataset.
CounterfactualExplanations.Models.train
β Methodtrain(M::Model, data::CounterfactualData)
Trains the model M on the data in data.
CounterfactualExplanations.Models.train
β Methodtrain(M::FluxModel, data::CounterfactualData; kwargs...)
Wrapper function to train Flux models.
CounterfactualExplanations.Models.train
β Methodtrain(
M::Model,
type::MLJModelType,
data::CounterfactualData,
)
Overloads the train function for MLJ models.
CounterfactualExplanations.Models.train
β Methodtrain(M::Model, type::DeepEnsemble, data::CounterfactualData; kwargs...)
Overloads the train function for deep ensembles.
CounterfactualExplanations.GenerativeModels.AbstractGMParams
β TypeBase type of generative model hyperparameter container.
CounterfactualExplanations.GenerativeModels.AbstractGenerativeModel
β TypeBase type for generative model.
CounterfactualExplanations.GenerativeModels.Encoder
β TypeEncoder
Constructs the encoder part of the VAE: a simple Flux neural network with one hidden layer and two linear output layers for the first two moments of the latent distribution.
CounterfactualExplanations.GenerativeModels.VAE
β TypeVAE <: AbstractGenerativeModel
Constructs the Variational Autoencoder. The VAE is a subtype of AbstractGenerativeModel. Any (sub-)type of AbstractGenerativeModel is accepted by latent space generators.
CounterfactualExplanations.GenerativeModels.VAE
β MethodVAE(input_dim;kws...)
Outer method for instantiating a VAE.
CounterfactualExplanations.GenerativeModels.VAEParams
β TypeVAEParams <: AbstractGMParams
The default VAE parameters describing both the encoder/decoder architecture and the training process.
Base.rand
β FunctionRandom.rand(encoder::Encoder, x, device=cpu)
Draws random samples from the latent distribution.
CounterfactualExplanations.GenerativeModels.Decoder
β MethodDecoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; activation=relu)
The default decoder architecture is just a Flux Chain with one hidden layer and a linear output layer.
CounterfactualExplanations.GenerativeModels.decode
β Methoddecode(generative_model::VAE, x::AbstractArray)
Decodes an array x using the VAE decoder.
CounterfactualExplanations.GenerativeModels.encode
β Methodencode(generative_model::VAE, x::AbstractArray)
Encodes an array x using the VAE encoder by sampling from the latent distribution: x is first passed through the encoder to obtain the mean and log-variance of the latent distribution, and a sample is then drawn from that distribution using the reparameterization trick. See Random.rand(encoder::Encoder, x, device=cpu) for more details.
CounterfactualExplanations.GenerativeModels.get_data
β Methodget_data(X::AbstractArray, y::AbstractArray, batch_size)
Prepares data for mini-batch training.
CounterfactualExplanations.GenerativeModels.get_data
β Methodget_data(X::AbstractArray, batch_size)
Prepares data for mini-batch training.
CounterfactualExplanations.GenerativeModels.reconstruct
β Functionreconstruct(generative_model::VAE, x, device=cpu)
Implements a full pass of some input x through the VAE: x ↦ x̂.
CounterfactualExplanations.GenerativeModels.reparameterization_trick
β Functionreparameterization_trick(μ,logσ,device=cpu)
Helper function that implements the reparameterization trick: z ∼ 𝒩(μ,σ²) ⇔ z = μ + σ ⊙ ε, ε ∼ 𝒩(0,I).
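The trick can be sketched in plain Julia as follows. This assumes logσ holds the log standard deviation, so σ = exp(logσ); the encode documentation speaks of a log-variance, and if logσ is in fact log σ², then σ would be exp(logσ / 2) instead. The function name reparameterize_sketch and its rng keyword are hypothetical.

```julia
using Random, Statistics

# Hypothetical sketch: z = μ + σ ⊙ ε with ε ~ N(0, I), assuming logσ = log(σ).
function reparameterize_sketch(μ::AbstractArray{T}, logσ::AbstractArray{T};
                               rng::AbstractRNG=Random.default_rng()) where {T<:AbstractFloat}
    ε = randn(rng, T, size(μ))   # noise drawn independently of μ and σ
    σ = exp.(logσ)               # recover the standard deviation
    return μ .+ σ .* ε           # deterministic, differentiable in μ and logσ
end

# Usage: the empirical moments approach μ = 2 and σ = 0.5.
z = reparameterize_sketch(fill(2.0, 100_000), fill(log(0.5), 100_000); rng=MersenneTwister(1))
(mean(z), std(z))
```

Moving the randomness into ε is what lets gradients flow through μ and logσ during VAE training.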
CounterfactualExplanations.Generators.Penalty
β TypeType union for acceptable argument types for the penalty field of GradientBasedGenerator.
CounterfactualExplanations.Generators.TCRExGenerator
β TypeT-CREx counterfactual generator class.
CounterfactualExplanations.Convergence.conditions_satisfied
β MethodConvergence.conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to check whether all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed change for every feature is smaller than one percent of that feature's standard deviation.
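The default rule above can be sketched as a single vectorised comparison. This is a hypothetical helper, not the package method: it assumes features are stored in rows and samples in columns of a data matrix X from which the per-feature standard deviations are estimated, and tol = 0.01 encodes the "one percent".

```julia
using Statistics

# Hypothetical sketch: converged once |Δᵢ| < tol * σᵢ for every feature i.
function converged_sketch(Δ::AbstractVector, X::AbstractMatrix; tol::Real=0.01)
    σ = vec(std(X; dims=2))   # per-feature standard deviation (features in rows)
    return all(abs.(Δ) .< tol .* σ)
end

X = [0.0 2.0; 0.0 20.0]            # feature stds ≈ 1.414 and 14.14
converged_sketch([0.01, 0.1], X)   # → true
converged_sketch([0.02, 0.1], X)   # → false
```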
CounterfactualExplanations.Generators._replace_nans
β Function_replace_nans(Δcounterfactual_state::AbstractArray, old_new::Pair=(NaN => 0))
Helper function to deal with exploding gradients. This is only a temporary fix and will be improved.
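A minimal sketch of the same fix using Base.replace, which compares with isequal and therefore matches NaN. The name replace_nans_sketch is hypothetical.

```julia
# Hypothetical sketch: swap NaN entries of a proposed update for zeros so that a
# single exploding-gradient step does not corrupt the counterfactual state.
replace_nans_sketch(Δ::AbstractArray, old_new::Pair=(NaN => 0.0)) = replace(Δ, old_new)

replace_nans_sketch([0.5, NaN, -1.0])   # → [0.5, 0.0, -1.0]
```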
CounterfactualExplanations.Generators.feature_selection!
β Methodfeature_selection!(ce::AbstractCounterfactualExplanation)
Performs feature selection to find the dimension with the closest (but not equal) values between the ce.factual (factual) and ce.counterfactual_state (counterfactual) arrays.
Arguments
ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.
Returns
nothing
The function iteratively modifies the ce.counterfactual_state counterfactual array by updating its elements to match the corresponding elements in the ce.factual factual array, one dimension at a time, until the predicted label of the modified ce.counterfactual_state matches the predicted label of the ce.factual array.
CounterfactualExplanations.Generators.find_closest_dimension
β Methodfind_closest_dimension(factual, counterfactual)
Finds the dimension with the closest (but not equal) values between the factual and counterfactual arrays.
Arguments
factual: The factual array.
counterfactual: The counterfactual array.
Returns
closest_dimension: The index of the dimension with the closest values.
The function iterates over the indices of the factual array and calculates the absolute difference between the corresponding elements in the factual and counterfactual arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in factual and counterfactual are equal.
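The search described above can be sketched as a single pass over the indices. The function name is hypothetical, and returning 0 when every dimension is equal is an assumed convention, not documented package behaviour.

```julia
# Hypothetical sketch: index of the smallest non-zero |factual[i] - counterfactual[i]|.
function find_closest_dimension_sketch(factual::AbstractVector, counterfactual::AbstractVector)
    closest_dimension = 0            # 0 signals "no dimension differs" (assumed convention)
    min_diff = Inf
    for i in eachindex(factual, counterfactual)
        diff = abs(factual[i] - counterfactual[i])
        if !iszero(diff) && diff < min_diff   # skip dimensions where values are equal
            min_diff = diff
            closest_dimension = i
        end
    end
    return closest_dimension
end

find_closest_dimension_sketch([1.0, 2.0, 3.0], [1.0, 2.5, 5.0])   # → 2
```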
CounterfactualExplanations.Generators.find_counterfactual
β Methodfind_counterfactual(model, factual_class, counterfactual_data, counterfactual_candidates)
Finds the index of the first counterfactual by predicting labels.
Arguments
model: The fitted model used for prediction.
target_class: Expected target class.
counterfactual_data: Data required for counterfactual generation.
counterfactual_candidates: The array of counterfactual candidates.
Returns
counterfactual: The index of the first counterfactual found.
CounterfactualExplanations.Generators.growing_spheres_generation!
β Methodgrowing_spheres_generation!(ce::AbstractCounterfactualExplanation)
Generate counterfactual candidates using the growing spheres generation algorithm.
Arguments
ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.
Returns
nothing
This function applies the growing spheres generation algorithm to generate counterfactual candidates. It starts by generating random points uniformly on a sphere and gradually reduces the search space until no counterfactuals are found. It then expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.
The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in ce.M. It checks whether any of the predicted labels differ from the factual class. The search space is reduced by halving the search radius and expanded by increasing it.
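The two-phase search described above can be sketched in plain Julia. Everything here is hypothetical: predict stands in for the model stored in ce.M and maps a d × n matrix to n labels; the initial radius η, batch size n and max_iter are illustrative parameters; and sample_layer uses Gaussian directions with a uniformly drawn radius instead of the package's hyper_sphere_coordinates.

```julia
using LinearAlgebra, Random

# Hypothetical helper: n points with low <= ‖z - x‖₂ < high around x.
function sample_layer(rng, x::AbstractVector, n::Int, low::Real, high::Real)
    D = randn(rng, length(x), n)
    D ./= sqrt.(sum(abs2, D; dims=1))             # unit directions, one per column
    r = low .+ (high - low) .* rand(rng, 1, n)    # radii in [low, high)
    return x .+ D .* r
end

# Hypothetical sketch of growing spheres (not the package implementation).
function growing_spheres_sketch(predict, x::AbstractVector; η::Real=1.0, n::Int=200,
                                max_iter::Int=100, rng::AbstractRNG=Random.default_rng())
    factual_class = predict(reshape(x, :, 1))[1]
    is_cf(Z) = predict(Z) .!= factual_class
    # Phase 1: halve the radius while the ball still contains counterfactuals.
    for _ in 1:max_iter
        any(is_cf(sample_layer(rng, x, n, 0.0, η))) || break
        η /= 2
    end
    # Phase 2: move outward layer by layer until a counterfactual appears.
    low, high = η, 2η
    for _ in 1:max_iter
        Z = sample_layer(rng, x, n, low, high)
        hits = findall(is_cf(Z))
        isempty(hits) || return Z[:, first(hits)]
        low, high = high, high + η
    end
    return nothing   # no counterfactual within max_iter layers
end

# Usage: a toy classifier that predicts true when the first feature exceeds 1.
cf = growing_spheres_sketch(Z -> vec(Z[1, :] .> 1.0), [0.0, 0.0]; rng=MersenneTwister(3))
```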
CounterfactualExplanations.Generators.h
β Methodh(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate complexity function for any generator.
CounterfactualExplanations.Generators.h
β Methodh(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided.
CounterfactualExplanations.Generators.h
β Methodh(generator::AbstractGenerator, penalty::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where no penalty is provided.
CounterfactualExplanations.Generators.h
β Methodh(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.
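The h overloads above rely on multiple dispatch over the penalty type. A minimal sketch of that pattern follows; h_sketch and dist are hypothetical names, and the vector method is an assumption about how several penalties might be combined, not documented package behaviour.

```julia
using LinearAlgebra

# Hypothetical sketch of dispatching on the penalty type (not the package code).
h_sketch(::Nothing, x) = 0.0                                 # no penalty
h_sketch(penalty::Function, x) = penalty(x)                  # single penalty function
h_sketch(penalty::Tuple, x) = penalty[1](x; penalty[2]...)   # (function, keyword arguments)
h_sketch(penalties::AbstractVector, x) = sum(h_sketch(p, x) for p in penalties)  # assumed: sum of penalties

# Usage with a toy distance penalty.
dist(x; p=2) = norm(x, p)
x = [3.0, 4.0]
h_sketch(dist, x)                     # → 5.0
h_sketch((dist, (p=1,)), x)           # → 7.0
h_sketch([dist, (dist, (p=1,))], x)   # → 12.0
```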
CounterfactualExplanations.Generators.hyper_sphere_coordinates
β Methodhyper_sphere_coordinates(n_search_samples::Int, instance::Vector{Float64}, low::Int, high::Int; p_norm::Int=2)
Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.
The implementation follows the random point picking over a sphere algorithm described in the paper "Learning Model-Agnostic Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kasneci (2020), presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html
The growing spheres method was originally proposed in the paper "Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al. (2018), presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems (2018).
Arguments
n_search_samples::Int: The number of search samples (int > 0).
instance::AbstractArray: The input point array.
low::AbstractFloat: The lower bound (float >= 0, low < high).
high::AbstractFloat: The upper bound (float >= 0, high > low).
p_norm::Integer: The norm parameter (int >= 1).
Returns
candidate_counterfactuals::Array: An array of candidate counterfactuals.
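A minimal sketch of the sampling step, assuming the Gaussian-normalisation approach from the MathWorld page cited above: Gaussian vectors are rescaled to unit p-norm (uniform on the sphere when p_norm = 2) and then multiplied by a radius drawn uniformly from [low, high). The function name and the rng keyword are hypothetical, and candidates are returned one per column.

```julia
using LinearAlgebra, Random

# Hypothetical sketch: candidates with low <= ‖z - instance‖_p < high.
function hyper_sphere_sketch(n_search_samples::Int, instance::AbstractVector{<:Real},
                             low::Real, high::Real; p_norm::Real=2,
                             rng::AbstractRNG=Random.default_rng())
    (n_search_samples > 0 && 0 <= low < high && p_norm >= 1) ||
        throw(ArgumentError("need n_search_samples > 0, 0 <= low < high, p_norm >= 1"))
    d = length(instance)
    candidates = Matrix{Float64}(undef, d, n_search_samples)
    for j in 1:n_search_samples
        direction = randn(rng, d)
        direction ./= norm(direction, p_norm)     # unit vector under the p-norm
        radius = low + (high - low) * rand(rng)   # distance from the instance
        candidates[:, j] = instance .+ radius .* direction
    end
    return candidates
end

C = hyper_sphere_sketch(500, [1.0, -2.0, 0.5], 0.5, 1.5; rng=MersenneTwister(7))
```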
CounterfactualExplanations.Generators.incompatible
β Methodincompatible(AbstractGenerator, AbstractCounterfactualExplanation)
Checks if the generator is incompatible with any of the additional specifications for the counterfactual explanations. By default, generators are assumed to be compatible.
CounterfactualExplanations.Generators.propose_state
β Methodpropose_state(
::Models.IsDifferentiable,
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation,
)
Proposes a new state based on backpropagation for gradient-based generators and differentiable models.
CounterfactualExplanations.Generators.total_loss
β Methodtotal_loss(ce::AbstractCounterfactualExplanation)
Computes the total loss of a counterfactual explanation with respect to the search objective.
CounterfactualExplanations.Generators.ℓ
β Methodℓ(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate loss function for any generator.
CounterfactualExplanations.Generators.ℓ
β Methodℓ(generator::AbstractGenerator, loss::Function, ce::AbstractCounterfactualExplanation)
Overloads the ℓ function for the case where a single loss function is provided.
CounterfactualExplanations.Generators.ℓ
β Methodℓ(generator::AbstractGenerator, loss::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the ℓ function for the case where no loss function is provided.
CounterfactualExplanations.Generators.∂h
β Method∂h(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.
If no penalty is provided, it returns 0.0. By default, Zygote does not compute gradients for constants and returns nothing instead, so a manual step is needed to override this behaviour. See here: https://discourse.julialang.org/t/zygote-gradient/26715.
CounterfactualExplanations.Generators.∂ℓ
β Method∂ℓ(
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation,
)
The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.
CounterfactualExplanations.Generators.∇
β Method∇(
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation,
)
The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It computes the weighted sum of the partial derivatives. It assumes that Zygote.jl has gradient access. If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
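To make the "weighted sum of partial derivatives" concrete, here is a sketch on a hypothetical toy objective: ℓ is the binary cross-entropy of a fixed logistic model (w, b) with respect to the target y, h is the squared distance to the factual x₀, and λ weights the penalty. The gradients are written out analytically instead of via Zygote.jl so that the example has no dependencies; all names are illustrative.

```julia
using LinearAlgebra

# Hypothetical toy objective for one counterfactual x (not the package code):
#   ℓ(x) = binary cross-entropy of a logistic model (w, b) w.r.t. target y
#   h(x) = ‖x - x₀‖², the squared distance to the factual x₀
sigmoid(z) = 1 / (1 + exp(-z))

∂ℓ_sketch(x, w, b, y) = (sigmoid(dot(w, x) + b) - y) .* w   # ∂ℓ/∂x
∂h_sketch(x, x₀::AbstractVector) = 2 .* (x .- x₀)            # ∂h/∂x
∂h_sketch(x, ::Nothing) = zero(x)                             # no penalty: zero gradient

# Gradient of ℓ(x) + λ h(x): a weighted sum of the partial derivatives.
∇_sketch(x, w, b, y, x₀, λ) = ∂ℓ_sketch(x, w, b, y) .+ λ .* ∂h_sketch(x, x₀)

∇_sketch([0.3, 0.8], [1.5, -0.5], 0.2, 1.0, [0.0, 0.0], 0.1)
```

A central finite-difference check of ℓ + λh against ∇_sketch is a cheap way to validate gradients of this kind.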
CounterfactualExplanations.Objectives.NeedsNeighbours
β TypePenalties that need access to neighbors in the target class.
CounterfactualExplanations.Objectives.NoPenaltyRequirements
β TypeBy default, penalties have no extra requirements.
CounterfactualExplanations.Objectives.PenaltyRequirements
β TypeA base type for traits describing the requirements a penalty may have.
CounterfactualExplanations.Objectives.PenaltyRequirements
β MethodThe distance_from_target method needs neighbors in the target class.
CounterfactualExplanations.Objectives.cos_dist
β Methodcos_dist(x,y)
Computes the cosine distance between two vectors.
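For reference, cosine distance is one minus the cosine similarity. A minimal sketch with a hypothetical name follows; note that the quantity is undefined (NaN) when either vector is zero.

```julia
using LinearAlgebra

# Hypothetical sketch: cosine distance = 1 - (x ⋅ y) / (‖x‖ ‖y‖).
cos_dist_sketch(x::AbstractVector, y::AbstractVector) = 1 - dot(x, y) / (norm(x) * norm(y))

cos_dist_sketch([1.0, 0.0], [0.0, 1.0])    # → 1.0 (orthogonal)
cos_dist_sketch([1.0, 0.0], [-1.0, 0.0])   # → 2.0 (opposite directions)
```

Vectors pointing in the same direction, such as [1.0, 2.0] and [2.0, 4.0], have a distance of (numerically) zero regardless of their magnitudes.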
CounterfactualExplanations.Objectives.distance_from_target
β Methoddistance_from_target(
ce::AbstractCounterfactualExplanation;
K::Int=50
)
Computes the distance of the counterfactual from samples in the target manifold. If choose_randomly is true, the function will randomly sample K neighbours from the target manifold. Otherwise, it will compute the pairwise distances and select the K closest neighbours.
Arguments
ce::AbstractCounterfactualExplanation: The counterfactual explanation.
K::Int=50: The number of neighbours to sample.
choose_randomly::Bool=true: Whether to sample neighbours randomly.
kwargs...: Additional keyword arguments for the distance function.
Returns
Δ::AbstractFloat: The distance from the counterfactual to the target manifold.
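The two neighbour-selection modes can be sketched as follows. The aggregation (mean Euclidean distance over the K selected samples), the column-per-sample layout and the function name are all assumptions made for illustration, not the package's documented behaviour.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical sketch: mean distance from x to K target-class samples (one per column).
function distance_from_target_sketch(x::AbstractVector, X_target::AbstractMatrix;
                                     K::Int=50, choose_randomly::Bool=true,
                                     rng::AbstractRNG=Random.default_rng())
    n = size(X_target, 2)
    K = min(K, n)                                       # cannot use more samples than exist
    dists = [norm(x .- c) for c in eachcol(X_target)]   # distance to every target sample
    idx = if choose_randomly
        randperm(rng, n)[1:K]          # K random neighbours
    else
        partialsortperm(dists, 1:K)    # the K closest neighbours
    end
    return mean(dists[idx])
end

X_target = [0.0 1.0 5.0; 0.0 0.0 0.0]
distance_from_target_sketch([0.0, 0.0], X_target; K=2, choose_randomly=false)   # → 0.5
```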
CounterfactualExplanations.Objectives.energy
β Methodenergy(M::AbstractModel, x::AbstractArray, t::Int)
Computes the energy of the model at a given state as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5.
CounterfactualExplanations.Objectives.energy_constraint
β Methodenergy_constraint(
ce::AbstractCounterfactualExplanation;
agg=mean,
reg_strength::AbstractFloat=0.0,
decay::AbstractFloat=0.9,
kwargs...,
)
Computes the energy constraint for the counterfactual explanation as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5. The energy constraint is a regularization term that penalizes the energy of the counterfactuals. The energy is computed as the negative logit of the target class.
Arguments
ce::AbstractCounterfactualExplanation: The counterfactual explanation.
agg::Function=mean: The aggregation function (only applicable in case num_counterfactuals > 1). Default is mean.
reg_strength::AbstractFloat=0.0: The regularization strength.
decay::AbstractFloat=0.9: The decay rate for the polynomial decay function (defaults to 0.9). Parameter a is set to 1.0 / ce.generator.opt.eta, such that the initial step size is equal to 1.0, not accounting for b. Parameter b is set to round(Int, max_steps / 20), where max_steps is the maximum number of iterations.
kwargs...: Additional keyword arguments.
Returns
ℓ::AbstractFloat: The energy constraint.
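The core of the constraint, with the energy computed as the negative logit of the target class and aggregated with agg across counterfactuals, can be sketched as below. The sketch deliberately omits reg_strength and decay because the documentation does not spell out how they enter the final term; the function names and the classes × counterfactuals logit layout are assumptions.

```julia
using Statistics

# Hypothetical sketch (not the package code): `logits` is a
# (number of classes) × (number of counterfactuals) matrix and t the target class.
energy_sketch(logits::AbstractMatrix, t::Int) = -logits[t, :]   # one energy per counterfactual

# Aggregate the energies with `agg` (only matters when there are several counterfactuals).
energy_constraint_sketch(logits::AbstractMatrix, t::Int; agg=mean) = agg(energy_sketch(logits, t))

logits = [1.0 3.0; 2.0 -1.0]
energy_constraint_sketch(logits, 1)                 # → -2.0
energy_constraint_sketch(logits, 1; agg=maximum)    # → -1.0
```

Lower energy corresponds to a larger target logit, so penalising energy pushes counterfactuals towards regions the model assigns confidently to the target class.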
CounterfactualExplanations.Objectives.model_loss_penalty
β Methodfunction model_loss_penalty(
ce::AbstractCounterfactualExplanation;
agg=mean
)
Additional penalty for ClaPROARGenerator.
CounterfactualExplanations.Objectives.needs_neighbours
β Methodneeds_neighbours(ce::AbstractCounterfactualExplanation)
Check if a counterfactual explanation needs access to neighbors in the target class.
CounterfactualExplanations.Objectives.needs_neighbours
β Methodneeds_neighbours(gen::AbstractGenerator)
Check if a generator needs access to neighbors in the target class.
Extensions
DecisionTreeExt.AtomicDecisionTree
β TypeType union for DecisionTree decision tree classifiers and regressors.
DecisionTreeExt.AtomicRandomForest
β TypeType union for DecisionTree random forest classifiers and regressors.
CounterfactualExplanations.DecisionTreeModel
β MethodCounterfactualExplanations.DecisionTreeModel(
model::AtomicDecisionTree; likelihood::Symbol=:classification_binary
)
Outer constructor for decision trees.
CounterfactualExplanations.Generators.TCRExGenerator
β Method(generator::Generators.TCRExGenerator)(
target::RawTargetType,
data::DataPreprocessing.CounterfactualData,
M::Models.AbstractModel
)
Applies the Generators.TCRExGenerator to a given target and data using the model M. For details see Bewley et al. (2024) [arXiv, PMLR].
CounterfactualExplanations.Models.Model
β Method(M::Models.Model)(
data::CounterfactualData,
type::CounterfactualExplanations.DecisionTreeModel;
kwargs...,
)
Constructs a decision tree for the given data. This method is used internally when a decision-tree model is constructed to be trained from scratch (i.e. no pre-trained model is supplied by the user).
CounterfactualExplanations.Models.Model
β Method(M::Models.Model)(
data::CounterfactualData, type::CounterfactualExplanations.RandomForestModel; kwargs...
)
Constructs a random forest for the given data.
CounterfactualExplanations.RandomForestModel
β MethodCounterfactualExplanations.RandomForestModel(
model::AtomicRandomForest; likelihood::Symbol=:classification_binary
)
Outer constructor for random forests.
CounterfactualExplanations.Generators.incompatible
β MethodGenerators.incompatible(gen::FeatureTweakGenerator, ce::CounterfactualExplanation)
Overloads the incompatible function for the FeatureTweakGenerator.
CounterfactualExplanations.Generators.propose_state
β MethodGenerators.propose_state(
generator::Generators.FeatureTweakGenerator, ce::AbstractCounterfactualExplanation
)
Overloads the Generators.propose_state method for the FeatureTweakGenerator.
DecisionTreeExt.calculate_delta
β Methodcalculate_delta(ce::AbstractCounterfactualExplanation, penalty::Vector{Function})
Calculates the penalty for the proposed feature tweak.
Arguments
ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
Returns
delta::Float64: The calculated penalty for the proposed feature tweak.
DecisionTreeExt.classify_prototypes
β Methodclassify_prototypes(prototypes, rule_assignments, bounds)
Builds the second tree model using the given prototypes as inputs and their corresponding rule_assignments as labels. Split thresholds are restricted to the bounds, which can be computed using partition_bounds(rules). For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.cre
β Methodcre(rules, x, X)
Computes the counterfactual rule explanations (CRE) for a given point $x$ and a set of $rules$, where the $rules$ correspond to the set of maximal-valid rules for some given target. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.esatisfactory_instance
β Methodesatisfactory_instance(generator::FeatureTweakGenerator, x::AbstractArray, paths::Dict{String, Dict{String, Any}})
Returns an epsilon-satisfactory counterfactual for x based on the paths provided.
Arguments
generator::FeatureTweakGenerator: The feature tweak generator.
x::AbstractArray: The factual instance.
paths::Dict{String, Dict{String, Any}}: A list of paths to the leaves of the tree to be used for tweaking the feature.
Returns
esatisfactory::AbstractArray: The epsilon-satisfactory instance.
Example
esatisfactory = esatisfactory_instance(generator, x, paths) # returns an epsilon-satisfactory counterfactual for x based on the paths provided
DecisionTreeExt.extract_leaf_rules
β Methodextract_leaf_rules(root::DT.Root)
Extracts leaf decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $L$ hyperrectangles. The rules are returned as a vector of tuples containing 2-element tuples, where each 2-element tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.extract_leaf_rules
β Methodextract_leaf_rules(node::Union{DT.Leaf,DT.Node}, conditions::AbstractArray, decisions::AbstractArray)
See extract_leaf_rules(root::DT.Root) for details.
DecisionTreeExt.extract_rules
β Methodextract_rules(root::DT.Root)
Extracts decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $2L-1$ hyperrectangles. The rules are returned as a vector of vectors of 2-element tuples, where each tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.extract_rules
β Methodextract_rules(node::DT.Node, conditions::AbstractArray)
DecisionTreeExt.get_individual_classifiers
β Methodget_individual_classifiers(M::Model)
Returns the individual classifiers in the forest. If the input is a decision tree, the method returns the decision tree itself inside an array.
Arguments
M::Model: The model selected by the user.
model::CounterfactualExplanations.D
Returns
classifiers::AbstractArray: An array of individual classifiers in the forest.
DecisionTreeExt.grow_surrogate
β Methodgrow_surrogate(
generator::Generators.TCRExGenerator, X::AbstractArray, ŷ::AbstractArray
)
Grows the tree-based surrogate model for the Generators.TCRExGenerator. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.grow_surrogate
β Methodgrow_surrogate(
generator::Generators.TCRExGenerator, data::CounterfactualData, M::AbstractModel
)
Overloads the grow_surrogate function to accept a CounterfactualData and an AbstractModel to grow a surrogate model. See grow_surrogate(generator::Generators.TCRExGenerator, X::AbstractArray, ŷ::AbstractArray).
DecisionTreeExt.induced_grid
β Methodinduced_grid(rules)
Computes the induced grid of the given rules. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.issubrule
β Methodissubrule(rule, otherrule)
Checks if the rule hyperrectangle is a subset of the otherrule hyperrectangle. For details see Bewley et al. (2024) [arXiv, PMLR].
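With rules represented as vectors of (lower, upper) tuples, one per feature, as in the extract_rules description, the subset check reduces to comparing bounds feature by feature. The name issubrule_sketch is hypothetical and this is not the extension's code.

```julia
# Hypothetical sketch: `rule` ⊆ `otherrule` if its box lies inside the other box
# in every feature, i.e. other.lower <= rule.lower and rule.upper <= other.upper.
issubrule_sketch(rule, otherrule) =
    all(o[1] <= r[1] && r[2] <= o[2] for (r, o) in zip(rule, otherrule))

rule = [(0.0, 1.0), (-Inf, 2.0)]
other = [(-Inf, 1.0), (-Inf, Inf)]
issubrule_sketch(rule, other)   # → true
issubrule_sketch(other, rule)   # → false
```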
DecisionTreeExt.max_valid
β Methodmax_valid(rules, X, fx, target, τ)
Returns the maximal-valid rules for a given target and accuracy threshold τ. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.partition_bounds
β Methodpartition_bounds(rules, dim::Int)
Computes the set of (unique) bounds for each rule in rules along the dim-th dimension. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.partition_bounds
β Methodpartition_bounds(rules)
Computes the set of (unique) bounds for each rule in rules and all dimensions. For details see Bewley et al. (2024) [arXiv, PMLR].
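A sketch of both partition_bounds variants, again treating each rule as a vector of (lower, upper) tuples. Keeping ±Inf in the result and returning the bounds sorted are assumptions; the names are hypothetical.

```julia
# Hypothetical sketch: distinct bounds that the rules impose along feature `dim`.
partition_bounds_sketch(rules, dim::Int) =
    sort(unique(vcat([rule[dim][1] for rule in rules], [rule[dim][2] for rule in rules])))

# Across all dimensions: one sorted vector of bounds per feature.
partition_bounds_sketch(rules) = [partition_bounds_sketch(rules, d) for d in 1:length(first(rules))]

rules = [[(-Inf, 1.0), (0.0, Inf)], [(1.0, Inf), (0.0, 2.0)]]
partition_bounds_sketch(rules, 1)   # → [-Inf, 1.0, Inf]
partition_bounds_sketch(rules)      # → [[-Inf, 1.0, Inf], [0.0, 2.0, Inf]]
```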
DecisionTreeExt.prototype
β Methodprototype(rule, X; pick_arbitrary::Bool=true)
Picks an arbitrary point $x^C \in X$ (i.e. prototype) from the subset of $X$ that is contained by rule $R_i$. If pick_arbitrary is set to false, the prototype is instead computed as the average across all samples. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_accuracy
β Methodrule_accuracy(rule, X, fx, target)
Computes the accuracy of the rule on the data X for predicted outputs fx and the target. Accuracy is defined as the fraction of points contained by the rule for which the predicted values match the target. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_changes
β Methodrule_changes(rule, x)
Computes the number of feature changes necessary for x to be contained by rule $R_i$. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_contains
β Methodrule_contains(rule, X)
Returns the subset of X that is contained by rule $R_i$. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_cost
β Methodrule_cost(rule, x, X)
Computes the cost for $x$ to be contained by rule $R_i$, where cost is defined as rule_changes(rule, x) - rule_feasibility(rule, X). For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_feasibility
β Methodrule_feasibility(rule, X)
Computes the feasibility of a rule $R_i$ for a given dataset. Feasibility is defined as the fraction of the data points that satisfy the rule. For details see Bewley et al. (2024) [arXiv, PMLR].
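The rule metrics above (containment, feasibility, accuracy, changes and cost) can be sketched together. Assumptions: a rule is a vector of (lower, upper) tuples, one per feature, read as half-open intervals [lower, upper); X stores one sample per column; and an empty rule is given accuracy 0. All names are hypothetical and this is not the extension's code.

```julia
# Hypothetical helper: does point x satisfy every feature bound of the rule?
in_rule(rule, x) = all(lb <= xi < ub for ((lb, ub), xi) in zip(rule, x))

rule_contains_sketch(rule, X) = X[:, [in_rule(rule, c) for c in eachcol(X)]]
rule_feasibility_sketch(rule, X) = count(in_rule(rule, c) for c in eachcol(X)) / size(X, 2)

function rule_accuracy_sketch(rule, X, fx, target)
    inside = [in_rule(rule, c) for c in eachcol(X)]
    any(inside) || return 0.0                       # assumption: empty rule has accuracy 0
    return count(fx[inside] .== target) / count(inside)
end

# Number of features whose value falls outside the rule's bounds.
rule_changes_sketch(rule, x) = count(!(lb <= xi < ub) for ((lb, ub), xi) in zip(rule, x))
rule_cost_sketch(rule, x, X) = rule_changes_sketch(rule, x) - rule_feasibility_sketch(rule, X)

X = [0.5 1.5 0.2 0.9; 0.5 0.5 3.0 0.1]   # 2 features × 4 samples
rule = [(0.0, 1.0), (0.0, 1.0)]
rule_feasibility_sketch(rule, X)         # → 0.5
rule_cost_sketch(rule, [2.0, 0.5], X)    # → 0.5 (one change minus feasibility 0.5)
```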
DecisionTreeExt.search_path
β Functionsearch_path(tree::Union{DT.Leaf, DT.Node}, target::RawTargetType, path::AbstractArray)
Returns a path index list with the inequality symbols, thresholds and feature indices.
Arguments
tree::Union{DT.Leaf, DT.Node}: The root node of a decision tree.
target::RawTargetType: The target class.
path::AbstractArray: A list containing the paths found thus far.
Returns
paths::AbstractArray: A list of paths to the leaves of the tree to be used for tweaking the feature.
Example
paths = search_path(tree, target) # returns a list of paths to the leaves of the tree to be used for tweaking the feature
DecisionTreeExt.wrap_decision_tree
β Functionwrap_decision_tree(node::TreeNode, X, y)
Turns a custom decision tree into a DecisionTree.Root object from the DecisionTree.jl package.
DecisionTreeExt.wrap_decision_tree
β Methodwrap_decision_tree(node::TreeNode)
CounterfactualExplanations.JEM
β MethodCounterfactualExplanations.JEM(
model::JointEnergyModels.JointEnergyClassifier; likelihood::Symbol=:classification_multi
)
Outer constructor for a joint energy model (a JointEnergyClassifier from JointEnergyModels.jl).
CounterfactualExplanations.Models.Model
β MethodModels.Model(model, type::CounterfactualExplanations.JEM; likelihood::Symbol=:classification_multi)
Overloaded constructor for JEM models.
CounterfactualExplanations.Models.Model
β Method(M::Model)(data::CounterfactualData, type::JEM; kwargs...)
Constructs a JEM for the given data.
CounterfactualExplanations.Models.load_mnist_model
β MethodModels.load_mnist_model(type::CounterfactualExplanations.JEM)
Overload for loading a pre-trained model for the JEM model type.
CounterfactualExplanations.Models.logits
β MethodModels.logits(M::JEM, X::AbstractArray)
Calculates the logit scores output by the model M for the input data X.
Arguments
M::JEM: The model selected by the user. Must be a model from the MLJ library.
X::AbstractArray: The feature vector for which the logit scores are calculated.
Returns
logits::Matrix: A matrix of logits for each output class for each data point in X.
Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x
CounterfactualExplanations.Models.probs
β MethodModels.probs(
M::Models.Model,
type::CounterfactualExplanations.JEM,
X::AbstractArray,
)
Overloads the Models.probs method for JEM models.
CounterfactualExplanations.Models.train
β Methodtrain(M::JEM, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
M::JEM: The wrapper for a JEM model.
data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
M::JEM: The fitted JEM model.
CounterfactualExplanations.LaplaceReduxModel
β MethodCounterfactualExplanations.LaplaceReduxModel(
model::LaplaceRedux.Laplace; likelihood::Symbol=:classification_binary
)
Outer constructor for a neural network with Laplace Approximation from LaplaceRedux.jl.
CounterfactualExplanations.Models.Model
β Method(M::Model)(data::CounterfactualData, type::LaplaceReduxModel; kwargs...)
Constructs a LaplaceReduxModel for the given data.
CounterfactualExplanations.Models.logits
β Methodlogits(M::LaplaceReduxModel, X::AbstractArray)
Predicts the logit scores for the input data X using the model M.
CounterfactualExplanations.Models.probs
β Methodprobs(M::LaplaceReduxModel, X::AbstractArray)
Predicts the probabilities of the classes for the input data X using the model M.
CounterfactualExplanations.Models.train
β Methodtrain(M::LaplaceReduxModel, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
M::LaplaceReduxModel: The wrapper for a LaplaceReduxModel model.
data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
M::LaplaceReduxModel: The fitted LaplaceReduxModel model.
NeuroTreeExt.AtomicNeuroTree
β TypeType union for NeuroTree classifiers and regressors.
CounterfactualExplanations.Models.Model
β Method(M::Model)(data::CounterfactualData, type::NeuroTreeModel; kwargs...)
Constructs a differentiable tree-based model for the given data.
CounterfactualExplanations.NeuroTreeModel
β MethodCounterfactualExplanations.NeuroTreeModel(
model::AtomicNeuroTree; likelihood::Symbol=:classification_binary
)
Outer constructor for a differentiable tree-based model from NeuroTreeModels.jl.
CounterfactualExplanations.Models.logits
β MethodModels.logits(M::NeuroTreeModel, X::AbstractArray)
Calculates the logit scores output by the model M for the input data X.
Arguments
M::NeuroTreeModel: The model selected by the user. Must be a model from the MLJ library.
X::AbstractArray: The feature vector for which the logit scores are calculated.
Returns
logits::Matrix: A matrix of logits for each output class for each data point in X.
Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x
CounterfactualExplanations.Models.probs
β MethodModels.probs(
M::Models.Model,
type::CounterfactualExplanations.NeuroTreeModel,
X::AbstractArray,
)
Overloads the probs method for NeuroTree models.
CounterfactualExplanations.Models.train
β Methodtrain(M::NeuroTreeModel, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
M::NeuroTreeModel: The wrapper for a NeuroTree model.
data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
M::NeuroTreeModel: The fitted NeuroTree model.