Reference
In this reference, you will find a detailed overview of the package API.
Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented.
(Diátaxis)
In other words, you come here because you want to take a very close look at the code 🧐.
Content
Exported functions
CounterfactualExplanations.RawOutputArrayType (Type)
RawOutputArrayType
A type union for the allowed types for the output array y.
CounterfactualExplanations.RawTargetType (Type)
RawTargetType
A type union for the allowed types for the target variable.
CounterfactualExplanations.flux_training_params (Constant)
flux_training_params
The default training parameters for FluxModels, etc.
CounterfactualExplanations.AbstractConvergence (Type)
An abstract type that serves as the base type for convergence objects.
CounterfactualExplanations.AbstractCounterfactualExplanation (Type)
Base type for counterfactual explanations.
CounterfactualExplanations.AbstractGenerator (Type)
An abstract type that serves as the base type for counterfactual generators.
CounterfactualExplanations.AbstractMeasure (Type)
An abstract type that serves as the base type for measures. Objects of type AbstractMeasure need to be callable.
CounterfactualExplanations.AbstractModel (Type)
Base type for models.
CounterfactualExplanations.AbstractPenalty (Type)
An abstract type for penalty functions.
CounterfactualExplanations.CounterfactualExplanation (Type)
A struct that collects all information relevant to a specific counterfactual explanation for a single individual.
CounterfactualExplanations.CounterfactualExplanation (Method)
function CounterfactualExplanation(;
x::AbstractArray,
target::RawTargetType,
data::CounterfactualData,
M::Models.AbstractModel,
generator::Generators.AbstractGenerator,
num_counterfactuals::Int = 1,
initialization::Symbol = :add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)
Outer method to construct a CounterfactualExplanation structure.
CounterfactualExplanations.CounterfactualExplanation (Method)
(ce::CounterfactualExplanation)()::FlattenedCE
Calling the ce::CounterfactualExplanation object returns a FlattenedCE instance, which is the flattened version of the original.
CounterfactualExplanations.EncodedOutputArrayType (Type)
EncodedOutputArrayType
Type of the encoded output array.
CounterfactualExplanations.EncodedTargetType (Type)
EncodedTargetType
Type of the encoded target variable.
CounterfactualExplanations.FlattenedCE (Type)
FlattenedCE <: AbstractCounterfactualExplanation
A flattened representation of a CounterfactualExplanation, containing only the factual, target, and counterfactual attributes. This can be useful for compact storage or transmission of explanations.
CounterfactualExplanations.OutputEncoder (Type)
OutputEncoder
The OutputEncoder takes a raw output array (y) and encodes it.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)(ynew::RawTargetType)
When called on a new value ynew, the OutputEncoder encodes it based on the initial encoding.
CounterfactualExplanations.OutputEncoder (Method)
(encoder::OutputEncoder)()
When called without arguments, the OutputEncoder returns the encoded output array.
CounterfactualExplanations.flatten (Method)
flatten(ce::CounterfactualExplanation)
Alias for (ce::CounterfactualExplanation)(). Converts a CounterfactualExplanation to its flattened form.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
x::Base.Iterators.Zip,
target::RawTargetType,
data::CounterfactualData,
M::Models.AbstractModel,
generator::AbstractGenerator;
kwargs...,
)
Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
x::Matrix,
target::RawTargetType,
data::CounterfactualData,
M::Models.AbstractModel,
generator::AbstractGenerator;
num_counterfactuals::Int=1,
initialization::Symbol=:add_perturbation,
convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
timeout::Union{Nothing,Real}=nothing,
return_flattened::Bool=false,
)
The core function used to run the counterfactual search for a given factual x, target, counterfactual data, model and generator. Keyword arguments control, among other things, the number of counterfactuals, the initialization and the convergence criterion; the latter determines settings such as the decision threshold for the predicted target class probability and the maximum number of iterations.
Arguments
- x::Matrix: Factual data point.
- target::RawTargetType: Target class.
- data::CounterfactualData: Counterfactual data.
- M::Models.AbstractModel: Fitted model.
- generator::AbstractGenerator: Generator.
- num_counterfactuals::Int=1: Number of counterfactuals to generate for the factual.
- initialization::Symbol=:add_perturbation: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
- convergence::Union{AbstractConvergence,Symbol}=:decision_threshold: Convergence criterion. By default, convergence is based on the decision threshold. Possible values are :decision_threshold, :max_iter, :generator_conditions or a concrete convergence object (e.g. DecisionThresholdConvergence).
- timeout::Union{Nothing,Real}=nothing: Timeout in seconds.
- return_flattened::Bool=false: If true, the flattened CE is returned instead of a CE object.
- callback::Union{Nothing,Function}: An optional callback function that takes a CounterfactualExplanation as its only positional input.
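The core idea behind a gradient-based search can be illustrated with a minimal sketch in plain Julia. It assumes a hypothetical binary logistic model with weights w and bias b, a squared L2 distance penalty weighted by λ, a fixed step size η, and decision-threshold convergence. The function name generic_search is hypothetical; this is not the package's GenericGenerator implementation.

```julia
using LinearAlgebra

# Logistic function for the hypothetical binary model p(target | x) = sigmoid(w'x + b).
sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical sketch of a gradient-based counterfactual search:
# minimise -log p(target | x_cf) + λ * ||x_cf - x||^2 by plain gradient descent
# and stop as soon as the target probability exceeds `threshold`.
function generic_search(x::Vector{Float64}, w::Vector{Float64}, b::Float64;
                        λ=0.1, η=0.5, threshold=0.5, max_iter=1000)
    x_cf = copy(x)                               # start the search at the factual
    for iter in 0:max_iter-1
        p = sigmoid(dot(w, x_cf) + b)
        p > threshold && return x_cf, iter       # decision-threshold convergence
        # Gradient of -log(p) is -(1 - p) * w; gradient of λ||x_cf - x||^2 is 2λ(x_cf - x).
        grad = -(1 - p) .* w .+ 2λ .* (x_cf .- x)
        x_cf .-= η .* grad
    end
    return x_cf, max_iter                        # stopped by the iteration budget
end
```

For example, generic_search([-1.0, -1.0], [1.0, 1.0], 0.0) moves the point toward the decision boundary x1 + x2 = 0 and stops after a few steps; a larger λ keeps the counterfactual closer to the factual at the cost of slower progress.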
Examples
Generic generator
julia> using CounterfactualExplanations
julia> using CounterfactualExplanations.Generators
julia> using CounterfactualExplanations.Models
julia> using TaijaData
# Counterfactual data and model:
julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(counterfactual_data, :Linear);
julia> target = 2;
julia> factual = 1;
julia> chosen = rand(findall(Models.predict_label(M, counterfactual_data) .== factual));
julia> x = CounterfactualExplanations.select_factual(counterfactual_data, chosen);
# Search:
julia> generator = Generators.GenericGenerator();
julia> ce = generate_counterfactual(x, target, counterfactual_data, M, generator);
julia> CounterfactualExplanations.converged(ce.convergence, ce)
true
Broadcasting
The generate_counterfactual method can also be broadcast over a tuple containing an array. This allows for generating counterfactuals for multiple factuals in a single call.
julia> using CounterfactualExplanations
julia> using CounterfactualExplanations.Generators
julia> using CounterfactualExplanations.Models
julia> using TaijaData
# Counterfactual data and model:
julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(counterfactual_data, :Linear);
julia> target = 2;
julia> factual = 1;
julia> chosen = rand(findall(Models.predict_label(M, counterfactual_data) .== factual), 5);
julia> xs = CounterfactualExplanations.select_factual(counterfactual_data, chosen);
julia> generator = Generators.GenericGenerator();
julia> ces = generate_counterfactual.(xs, target, counterfactual_data, M, generator);
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(x::Tuple{<:AbstractArray}, args...; kwargs...)
Overloads the generate_counterfactual method to accept a tuple containing an array. This allows for broadcasting over Zip iterators.
CounterfactualExplanations.generate_counterfactual (Method)
generate_counterfactual(
x::Vector{<:Matrix},
target::RawTargetType,
data::CounterfactualData,
M::Models.AbstractModel,
generator::AbstractGenerator;
kwargs...,
)
Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.
CounterfactualExplanations.get_global_ad_backend (Method)
get_global_ad_backend()
Get the currently set automatic differentiation backend.
Returns
- The global automatic differentiation backend as an instance of DI.AutoZygote.
CounterfactualExplanations.get_target_index (Method)
get_target_index(y_levels, target)
Utility that returns the index of target in y_levels.
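For intuition, looking up the target index amounts to finding target among the label levels. The sketch below is a hypothetical re-implementation (get_target_index_sketch and onehot_sketch are illustrative names); the one-hot helper assumes a multi-class encoding, and the package's encoding of binary targets may differ.

```julia
# Hypothetical re-implementation: position of `target` in `y_levels`, or `nothing`.
get_target_index_sketch(y_levels, target) = findfirst(==(target), y_levels)

# Hypothetical one-hot encoding of `target` relative to `y_levels`.
function onehot_sketch(y_levels, target)
    idx = get_target_index_sketch(y_levels, target)
    idx === nothing && throw(ArgumentError("target $target not found in y_levels"))
    v = zeros(Int, length(y_levels))
    v[idx] = 1
    return v
end
```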
CounterfactualExplanations.num_counterfactuals (Method)
num_counterfactuals(flat_ce::FlattenedCE)
Extends the num_counterfactuals method to FlattenedCE.
CounterfactualExplanations.path (Method)
path(ce::CounterfactualExplanation)
A convenience method that returns the entire counterfactual path.
CounterfactualExplanations.set_global_ad_backend (Method)
set_global_ad_backend(backend)
Set the global automatic differentiation backend.
Arguments
- backend: The new backend to set, which must be an instance of DI.AbstractBackend.
CounterfactualExplanations.target_encoded (Method)
target_encoded(ce::CounterfactualExplanation, data::CounterfactualData)
Returns the encoded representation of ce.target.
CounterfactualExplanations.target_encoded (Method)
target_encoded(flat_ce::FlattenedCE, data::CounterfactualData)
Returns the encoded representation of flat_ce.target.
CounterfactualExplanations.target_probs (Function)
target_probs(
ce::CounterfactualExplanation,
x::Union{AbstractArray,Nothing}=nothing,
)
Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.
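For a multi-class model that outputs logits, the target class probability is the softmax entry at the target's index. This is a sketch under that assumption; softmax_sketch and target_prob_sketch are hypothetical helpers, and the package's models may already return probabilities directly.

```julia
# Numerically stable softmax: subtracting the maximum avoids overflow in exp.
function softmax_sketch(z::AbstractVector{<:Real})
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Hypothetical: probability of `target` given logits ordered like `y_levels`.
function target_prob_sketch(logits, y_levels, target)
    idx = findfirst(==(target), y_levels)
    idx === nothing && throw(ArgumentError("unknown target $target"))
    return softmax_sketch(logits)[idx]
end
```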
CounterfactualExplanations.terminated (Method)
terminated(ce::CounterfactualExplanation)
A convenience method that checks if the counterfactual search has terminated.
CounterfactualExplanations.total_steps (Method)
total_steps(ce::AbstractCounterfactualExplanation)
A convenience method that returns the total number of steps of the counterfactual search.
CounterfactualExplanations.Convergence.convergence_catalogue (Constant)
convergence_catalogue
A dictionary containing all convergence criteria.
CounterfactualExplanations.Convergence.DecisionThresholdConvergence (Type)
DecisionThresholdConvergence
Convergence criterion based on the target class probability threshold. The search stops when the target class probability exceeds the predefined threshold.
Fields
- decision_threshold::AbstractFloat: The predefined threshold for the target class probability.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the target class probability.
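The decision-threshold check can be sketched as follows, assuming a vector probs with the predicted target class probability of each counterfactual and reading min_success_rate as the required share of counterfactuals that reach the threshold. The function name is hypothetical, and counting ties as a success is an assumption.

```julia
using Statistics

# `probs`: predicted target-class probability for each counterfactual.
# Converged if the share of counterfactuals at or above the threshold is at least
# `min_success_rate` (assumption: ties count as success).
threshold_reached_sketch(probs::AbstractVector{<:Real};
                         decision_threshold=0.5, min_success_rate=0.75) =
    mean(probs .>= decision_threshold) >= min_success_rate
```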
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Type)
GeneratorConditionsConvergence
Convergence criterion for counterfactual explanations based on the generator conditions. The search stops when the gradients of the search objective are below a certain threshold and the generator conditions are satisfied.
Fields
- decision_threshold::AbstractFloat: The threshold for the decision probability.
- gradient_tol::AbstractFloat: The tolerance for the gradients of the search objective.
- max_iter::Int: The maximum number of iterations.
- min_success_rate::AbstractFloat: The minimum success rate for the generator conditions (across counterfactuals).
CounterfactualExplanations.Convergence.GeneratorConditionsConvergence (Method)
GeneratorConditionsConvergence(; decision_threshold=0.5, gradient_tol=1e-2, max_iter=100, min_success_rate=0.75, y_levels=nothing)
Outer constructor for GeneratorConditionsConvergence.
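A sketch of the generator-conditions check, assuming each column of grads holds the gradient of the search objective for one counterfactual and that a counterfactual satisfies the conditions when it is valid and its gradient norm is below gradient_tol. The function name and this exact combination of conditions are assumptions; the actual conditions are defined per generator in the Generators module.

```julia
using LinearAlgebra, Statistics

# `grads`: one column per counterfactual, holding the gradient of the search objective.
# `probs`: predicted target-class probability for each counterfactual.
function generator_conditions_sketch(grads::AbstractMatrix, probs::AbstractVector;
                                     decision_threshold=0.5, gradient_tol=1e-2,
                                     min_success_rate=0.75)
    # A counterfactual satisfies the (assumed) conditions if it is valid and its
    # gradient norm has fallen below the tolerance.
    ok = [probs[j] >= decision_threshold && norm(grads[:, j]) < gradient_tol
          for j in eachindex(probs)]
    return mean(ok) >= min_success_rate
end
```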
CounterfactualExplanations.Convergence.MaxIterConvergence (Type)
MaxIterConvergence
Convergence criterion based on the maximum number of iterations.
Fields
- max_iter::Int: The maximum number of iterations.
CounterfactualExplanations.Convergence.conditions_satisfied (Method)
conditions_satisfied(gen::AbstractGenerator, ce::AbstractCounterfactualExplanation)
This function is overloaded in the Generators module to check whether the counterfactual search has converged with respect to generator conditions.
CounterfactualExplanations.Convergence.converged (Function)
converged(
convergence::MaxIterConvergence,
ce::AbstractCounterfactualExplanation,
x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the maximum number of iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached, regardless of the other convergence criteria.
CounterfactualExplanations.Convergence.converged (Function)
converged(
convergence::GeneratorConditionsConvergence,
ce::AbstractCounterfactualExplanation,
x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.
CounterfactualExplanations.Convergence.converged (Function)
converged(
convergence::DecisionThresholdConvergence,
ce::AbstractCounterfactualExplanation,
x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.
CounterfactualExplanations.Convergence.converged (Function)
converged(
convergence::InvalidationRateConvergence,
ce::AbstractCounterfactualExplanation,
x::Union{AbstractArray,Nothing}=nothing,
)
Checks if the counterfactual search has converged when the convergence criterion is the invalidation rate.
CounterfactualExplanations.Convergence.converged (Method)
converged(ce::AbstractCounterfactualExplanation)
Returns true if the counterfactual explanation has converged.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::AbstractConvergence)
Returns the convergence object.
CounterfactualExplanations.Convergence.get_convergence_type (Method)
get_convergence_type(convergence::Symbol)
Returns the convergence object from the dictionary of default convergence types.
CounterfactualExplanations.Convergence.invalidation_rate (Method)
invalidation_rate(ce::AbstractCounterfactualExplanation)
Calculates the invalidation rate of a counterfactual explanation.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
- kwargs: Additional keyword arguments to pass to the function.
Returns
The invalidation rate of the counterfactual explanation.
CounterfactualExplanations.Convergence.invalidation_rate (Method)
invalidation_rate(ce::AbstractCounterfactualExplanation)
Single-argument method for convenience.
CounterfactualExplanations.Convergence.threshold_reached (Function)
threshold_reached(ce::AbstractCounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Determines if the predefined threshold for the target class probability has been reached.
CounterfactualExplanations.Evaluation._serialization_state (Constant)
Global serializer state (allows or disallows serialization).
CounterfactualExplanations.Evaluation.all_measures (Constant)
All measures.
CounterfactualExplanations.Evaluation.default_measures (Constant)
The default evaluation measures.
CounterfactualExplanations.Evaluation.distance_measures (Constant)
All distance measures.
CounterfactualExplanations.Evaluation.plausibility_measures (Constant)
Available plausibility measures.
CounterfactualExplanations.Evaluation.Benchmark (Type)
A container for benchmarks of counterfactual explanations. Instead of subtyping DataFrame, it contains a DataFrame of evaluation measures (see this discussion for why we don't subtype DataFrame directly).
CounterfactualExplanations.Evaluation.Benchmark (Method)
Benchmark(evaluation::DataFrames.DataFrame; counterfactuals=nothing)
Constructs a Benchmark from an evaluation DataFrame.
CounterfactualExplanations.Evaluation.Benchmark (Method)
(bmk::Benchmark)(; agg=mean)
Returns a DataFrame containing evaluation measures aggregated by num_counterfactual.
CounterfactualExplanations.Evaluation.DefaultOutputIdentifier (Type)
Default output identifier (no specific ID).
CounterfactualExplanations.Evaluation.ExplicitCETransformer (Type)
The ExplicitCETransformer can be used to specify any arbitrary CE transformation.
CounterfactualExplanations.Evaluation.ExplicitOutputIdentifier (Type)
An explicit output identifier that takes the string value of id.
CounterfactualExplanations.Evaluation.IdentityTransformer (Type)
Default CE transformer that returns the input as is.
CounterfactualExplanations.Evaluation.MMD (Type)
MMD{K<:KernelFunctions.Kernel} <: AbstractDivergenceMetric
Concrete type for the Maximum Mean Discrepancy (MMD) metric.
CounterfactualExplanations.Evaluation.MMD (Method)
(m::MMD)(
x::AbstractArray,
y::AbstractArray,
n::Int;
kwrgs...
)
Computes the MMD between two datasets x and y, along with a p-value based on a null distribution of MMD values (unless m.compute_p=nothing) for a random subset of the data (of sample size n). The p-value is computed using a permutation test.
CounterfactualExplanations.Evaluation.MMD (Method)
(m::MMD)(x::AbstractArray, y::AbstractArray)
Computes the maximum mean discrepancy (MMD) between two datasets x and y. The MMD is a measure of the difference between two probability distributions: it is the largest difference in expected values over functions in the unit ball of a reproducing kernel Hilbert space, which equals the distance between the kernel mean embeddings of the two distributions. Empirically, the (squared) MMD is estimated as the sum of the average kernel values between pairs of columns (samples) within x and within y, minus twice the average kernel value between columns of x and columns of y. A larger MMD value indicates that the distributions are more different, while a value closer to zero suggests they are more similar. See also kernelsum.
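The empirical estimate described above can be sketched in plain Julia. This is a minimal illustration assuming a Gaussian (RBF) kernel with bandwidth σ and samples stored as columns; it computes the biased estimate of the squared MMD. The names gauss_kernel and mmd2_sketch are hypothetical and this is not the package's kernelsum-based implementation.

```julia
using Statistics

# Gaussian (RBF) kernel with bandwidth σ (assumed; the package's default kernel may differ).
gauss_kernel(a, b; σ=1.0) = exp(-sum(abs2, a .- b) / (2σ^2))

# Biased estimate of the squared MMD; columns of X and Y are samples.
function mmd2_sketch(X::AbstractMatrix, Y::AbstractMatrix; σ=1.0)
    k(a, b) = gauss_kernel(a, b; σ=σ)
    kxx = mean(k(x1, x2) for x1 in eachcol(X), x2 in eachcol(X))  # within x
    kyy = mean(k(y1, y2) for y1 in eachcol(Y), y2 in eachcol(Y))  # within y
    kxy = mean(k(x, y) for x in eachcol(X), y in eachcol(Y))       # between x and y
    return kxx + kyy - 2kxy
end
```

Identical samples give a value of zero, and shifting one sample further away increases the estimate, matching the interpretation given above.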
CounterfactualExplanations.Evaluation.NullSerializer (Type)
Null serializer (does not allow serialization).
CounterfactualExplanations.Evaluation.Serializer (Type)
Standard serializer (allows serialization).
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
data::CounterfactualData;
test_data::Union{Nothing,CounterfactualData}=nothing,
models::Dict{<:Any,<:Any}=standard_models_catalogue,
generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
measure::Union{Function,Vector{<:Function}}=default_measures,
n_individuals::Int=5,
n_runs::Int=1,
suppress_training::Bool=false,
factual::Union{Nothing,RawTargetType}=nothing,
target::Union{Nothing,RawTargetType}=nothing,
store_ce::Bool=false,
parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
dataname::Union{Nothing,Symbol,String}=nothing,
verbose::Bool=true,
vertical_splits::Union{Nothing,Int}=nothing,
storage_path::String=tempdir(),
kwrgs...,
)
Benchmark a set of counterfactuals for a given data set and additional inputs.
Arguments
- data::CounterfactualData: The dataset containing the factual and target labels.
- test_data::Union{Nothing,CounterfactualData}: Optional test data for evaluation. Defaults to nothing, in which case data is used for evaluation.
- models::Dict{<:Any,<:Any}: A dictionary of model objects keyed by their names. Defaults to standard_models_catalogue.
- generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}: Optional dictionary of generators keyed by their names. Defaults to nothing, in which case the whole generator_catalogue is used.
- measure::Union{Function,Vector{<:Function}}: The measure(s) to evaluate the counterfactuals against. Defaults to default_measures.
- n_individuals::Int=5: Number of individuals to generate for each model and generator.
- n_runs::Int=1: Number of runs for each model and generator.
- suppress_training::Bool=false: Whether to suppress training of models during benchmarking. This is useful if models have already been trained.
- factual::Union{Nothing,RawTargetType}: Optional factual label. Defaults to nothing, in which case factual labels are randomly sampled from the dataset.
- target::Union{Nothing,RawTargetType}: Optional target label. Defaults to nothing, in which case target labels are randomly sampled from the dataset.
- store_ce::Bool=false: Whether to store the CounterfactualExplanation objects for each counterfactual.
- parallelizer::Union{Nothing,AbstractParallelizer}=nothing: Parallelization strategy for generating and evaluating counterfactuals.
- dataname::Union{Nothing,Symbol,String}=nothing: Name of the dataset. Defaults to nothing.
- verbose::Bool=true: Whether to print verbose output during benchmarking.
- vertical_splits::Union{Nothing,Int}=nothing: Number of elements per vertical split for generating counterfactuals. Defaults to nothing. This can be useful if it is necessary to reduce peak memory usage by decreasing the number of counterfactuals generated at once. Lower values lead to smaller batches and hence a smaller peak load.
- storage_path::String=tempdir(): Path where interim results will be stored. Defaults to tempdir().
- concatenate_output::Bool=true: Whether to collect output from each run and concatenate them into a single output file.
Returns
- A dictionary containing the benchmark results, including mean and standard deviation of the measures across all runs.
Benchmarking Procedure
Runs the benchmarking exercise as follows:
- Randomly choose a factual and target label unless specified.
- If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
- Each of these models is then trained on the data.
- For each model separately, choose n_individuals randomly from the non-target (factual) class. For each generator, create a benchmark as in benchmark(xs::Union{AbstractArray,Base.Iterators.Zip}).
- Finally, concatenate the results.
If vertical_splits is set to an integer, the computations are split vertically into chunks of vertical_splits each. In this case, the results are stored in storage_path (a temporary directory by default) and concatenated afterwards.
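The effect of vertical_splits on peak memory can be illustrated with a simple chunked loop. In this hypothetical sketch, run_in_chunks applies a function to consecutive chunks of at most vertical_splits items and concatenates the results in memory; the benchmark described above writes interim results to storage_path instead.

```julia
# Hypothetical sketch: apply `f` to consecutive chunks of at most `vertical_splits`
# items and concatenate the per-chunk results.
function run_in_chunks(f, items::AbstractVector, vertical_splits::Int)
    results = Any[]
    for chunk in Iterators.partition(eachindex(items), vertical_splits)
        # Only one chunk is processed at a time, which bounds the peak workload.
        # (A disk-backed variant would store this chunk's results before moving on.)
        append!(results, f(items[chunk]))
    end
    return results
end
```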
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
x::Union{AbstractArray,Base.Iterators.Zip},
target::RawTargetType,
data::CounterfactualData;
models::Dict{<:Any,<:AbstractModel},
generators::Dict{<:Any,<:AbstractGenerator},
measure::Union{Function,Vector{<:Function}}=default_measures,
xids::Union{Nothing,AbstractArray}=nothing,
dataname::Union{Nothing,Symbol,String}=nothing,
verbose::Bool=true,
store_ce::Bool=false,
parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
kwrgs...,
)
First generates counterfactual explanations for factual x, the target and data using each of the provided models and generators. Then generates a Benchmark for the vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).
CounterfactualExplanations.Evaluation.benchmark (Method)
benchmark(
counterfactual_explanations::Vector{CounterfactualExplanation};
meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
measure::Union{Function,Vector{<:Function}}=default_measures,
store_ce::Bool=false,
)
Generates a Benchmark for a vector of counterfactual explanations. Optionally, meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it will be inferred automatically. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.
CounterfactualExplanations.Evaluation.concatenate_benchmarks (Method)
concatenate_benchmarks(storage_path::String)
Concatenates all benchmarks stored in storage_path into a single benchmark.
CounterfactualExplanations.Evaluation.evaluate (Function)
evaluate(
ce::CounterfactualExplanation,
meta_data::Union{Nothing,Dict}=nothing;
measure::Union{Function,Vector{<:Function}}=default_measures,
agg::Function=mean,
report_each::Bool=false,
output_format::Symbol=:Vector,
pivot_longer::Bool=true,
store_ce::Bool=false,
report_meta::Bool=false,
)
Computes evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is inferred automatically, unless this is overwritten by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.
Arguments
- ce: The counterfactual explanation to evaluate.
- meta_data: A vector of dictionaries containing meta data for each counterfactual explanation. If not provided, the default meta data is inferred from the counterfactual explanations.
- measure: The evaluation measures to compute. By default, the default_measures are computed.
- agg: The aggregation function to use for the evaluation measures. By default, the mean is used.
- report_each: If true, each evaluation measure is reported separately. Otherwise, the mean of all measures is reported.
- output_format: The format of the output. By default, a vector is returned.
- pivot_longer: If true, the evaluation measures are pivoted longer. Otherwise, they are stacked.
- store_ce: If true, the counterfactual explanation is stored in the evaluation DataFrame. Note: these objects are potentially large and can consume a lot of memory.
- report_meta: If true, meta data is reported. Otherwise, it is not.
CounterfactualExplanations.Evaluation.faithfulness (Method)
faithfulness(
ce::CounterfactualExplanation,
fun::typeof(Objectives.distance_from_target);
λ::AbstractFloat=1.0,
kwrgs...,
)
Computes the faithfulness of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the model posterior through SGLD (see distance_from_posterior).
CounterfactualExplanations.Evaluation.feature_sensitivity (Function)
feature_sensitivity(
ce::AbstractCounterfactualExplanation, d::Union{Int,Vector{Int}}=[1]; kwrgs...
)
Returns the sensitivity to feature(s) d in terms of absolute changes associated with the counterfactual. Any keyword arguments accepted by distance can be passed through kwrgs.
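As a rough illustration, the absolute changes for the selected features d can be computed as below, assuming the factual and counterfactual are plain vectors. The name feature_sensitivity_sketch is hypothetical, and the package may additionally aggregate these changes through distance.

```julia
# Hypothetical sketch: absolute change of feature(s) `d` between factual and counterfactual.
# `d` may be a single index (returns a scalar) or a vector of indices (returns a vector).
feature_sensitivity_sketch(x_factual::AbstractVector, x_cf::AbstractVector, d=[1]) =
    abs.(x_cf[d] .- x_factual[d])
```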
CounterfactualExplanations.Evaluation.get_global_ce_transform (Method)
Get the global CE transformer.
CounterfactualExplanations.Evaluation.global_ce_transform (Method)
global_ce_transform(transformer::AbstractCETransformer)
Sets the global CE transformer to transformer.
CounterfactualExplanations.Evaluation.global_output_identifier (Method)
global_output_identifier(identifier::AbstractOutputIdentifier)
Set the global output identifier to identifier and return its string representation. The global output identifier is used by default for all serialization operations.
CounterfactualExplanations.Evaluation.global_serializer (Method)
global_serializer(serializer::AbstractSerializer)
Set the global serializer to serializer and return its state. The global serializer is used by default for all serialization operations.
CounterfactualExplanations.Evaluation.mmd_null_dist (Function)
mmd_null_dist(
x::AbstractArray, y::AbstractArray, k::KernelFunctions.Kernel=default_kernel; l=1000
)
Compute the null distribution of MMD for two samples x and y through resampling (a permutation procedure) as follows:
- For each resample, pool the columns of x and y, shuffle them, and split them back into two samples of the original sizes.
- Compute the MMD between the shuffled samples.
- Repeat this process l times to obtain a null distribution of MMD values.
- Return the null distribution of MMD values.
Under the null hypothesis, x and y come from the same distribution, so shuffling the pooled samples should not systematically change the MMD.
CounterfactualExplanations.Evaluation.mmd_significance (Method)
mmd_significance(mmd::Number, mmd_null_dist::AbstractArray)
Compute the p-value of the MMD test as the proportion of MMD values in the null distribution that are greater than or equal to the observed MMD value.
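The permutation procedure and the p-value can be sketched together. The sketch assumes a Gaussian-kernel estimate of the squared MMD as a stand-in for the package's MMD, and all function names here (gauss_kernel, mmd2, mmd_null_dist_sketch, mmd_significance_sketch) are hypothetical.

```julia
using Random, Statistics

# Gaussian-kernel estimate of the squared MMD (stand-in for the package's MMD).
gauss_kernel(a, b; σ=1.0) = exp(-sum(abs2, a .- b) / (2σ^2))

function mmd2(X::AbstractMatrix, Y::AbstractMatrix; σ=1.0)
    k(a, b) = gauss_kernel(a, b; σ=σ)
    return mean(k(a, b) for a in eachcol(X), b in eachcol(X)) +
           mean(k(a, b) for a in eachcol(Y), b in eachcol(Y)) -
           2 * mean(k(a, b) for a in eachcol(X), b in eachcol(Y))
end

# Permutation null distribution: pool the columns, shuffle, split at the original sizes.
function mmd_null_dist_sketch(X, Y; l=1000, rng=Random.default_rng())
    Z = hcat(X, Y)
    n = size(X, 2)
    null_vals = Vector{Float64}(undef, l)
    for i in 1:l
        Zp = Z[:, randperm(rng, size(Z, 2))]
        null_vals[i] = mmd2(Zp[:, 1:n], Zp[:, n+1:end])
    end
    return null_vals
end

# p-value: share of null values at least as large as the observed statistic.
mmd_significance_sketch(observed, null_vals) = mean(null_vals .>= observed)
```

With two clearly separated samples, almost no shuffled split reproduces the observed discrepancy, so the p-value is close to zero.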
CounterfactualExplanations.Evaluation.plausibility (Method)
plausibility(
ce::CounterfactualExplanation,
fun::typeof(Objectives.distance_from_target);
K=nothing,
kwrgs...,
)
Computes the plausibility of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.plausibility (Method)
plausibility(
ce::CounterfactualExplanation,
fun::typeof(Objectives.distance_from_target);
K=nothing,
kwrgs...,
)
Computes the plausibility of a counterfactual explanation based on the distance between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.plausibility (Method)
plausibility(
ce::CounterfactualExplanation,
fun::typeof(Objectives.distance_from_target_cosine);
K=nothing,
kwrgs...,
)
Computes the plausibility of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.plausibility_cosine (Method)
plausibility_cosine(ce::CounterfactualExplanation; kwrgs...)
Computes the plausibility of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.plausibility_distance_from_target (Method)
plausibility_distance_from_target(ce::CounterfactualExplanation; kwrgs...)
Computes the plausibility of a counterfactual explanation based on the distance between the counterfactual and samples drawn from the target distribution.
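Distance-based and cosine-based plausibility can be sketched with two small helpers. The sketch assumes X_target holds target-class samples as columns and, as an assumption, averages the distance to the K nearest of them; the package's sampling of target points may differ, and both function names are hypothetical.

```julia
using LinearAlgebra, Statistics

# Hypothetical: mean Euclidean distance from `x_cf` to its K nearest target samples
# (columns of `X_target`). Smaller values suggest a more plausible counterfactual.
function distance_from_target_sketch(x_cf::AbstractVector, X_target::AbstractMatrix; K=5)
    d = [norm(x_cf .- col) for col in eachcol(X_target)]
    return mean(partialsort(d, 1:min(K, length(d))))
end

# Cosine similarity between two vectors (1 means same direction, 0 means orthogonal).
cosine_similarity(a, b) = dot(a, b) / (norm(a) * norm(b))
```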
CounterfactualExplanations.Evaluation.plausibility_energy_differential (Method)
plausibility_energy_differential(ce::CounterfactualExplanation; kwrgs...)
Computes the plausibility of a counterfactual explanation based on the energy differential between the counterfactual and samples drawn from the target distribution.
CounterfactualExplanations.Evaluation.redundancy (Method)
redundancy(ce::CounterfactualExplanation)
Computes the feature redundancy, that is, the number of features that remain unchanged from their original, factual values.
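Counting unchanged features for a single counterfactual is straightforward. The helper redundancy_sketch is hypothetical and treats a feature as unchanged when its absolute change is at most atol (exact equality by default).

```julia
# Hypothetical: number of features whose absolute change is at most `atol`
# (exact equality with the default atol = 0).
redundancy_sketch(x_factual::AbstractVector, x_cf::AbstractVector; atol=0.0) =
    count(abs.(x_cf .- x_factual) .<= atol)
```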
CounterfactualExplanations.Evaluation.validity (Method)
validity(ce::CounterfactualExplanation; γ=0.5)
Checks if the counterfactual search has been successful in that the predicted label corresponds to the specified target. In case multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.
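With multiple counterfactuals, validity reduces to a proportion. This sketch assumes a counterfactual counts as valid when its predicted target class probability exceeds γ; the function name validity_sketch is hypothetical.

```julia
using Statistics

# Hypothetical: `probs` holds the predicted target-class probability of each
# counterfactual; a counterfactual counts as valid if its probability exceeds γ.
validity_sketch(probs::AbstractVector{<:Real}; γ=0.5) = mean(probs .> γ)
```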
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
CounterfactualData(
X::AbstractMatrix,
y::RawOutputArrayType;
mutability::Union{Vector{Symbol},Nothing}=nothing,
domain::Union{Any,Nothing}=nothing,
features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
features_continuous::Union{Vector{Int},Nothing}=nothing,
input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)
This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.
Examples
using CounterfactualExplanations.Data
x, y = toy_data_linear()
X = hcat(x...)
counterfactual_data = CounterfactualData(X, y')
CounterfactualExplanations.DataPreprocessing.CounterfactualData (Method)
function CounterfactualData(
X::Tables.MatrixTable,
y::RawOutputArrayType;
kwrgs...
)
Outer constructor method that accepts a Tables.MatrixTable. By default, the indices of categorical and continuous features are automatically inferred from the features' scitype.
CounterfactualExplanations.DataPreprocessing.apply_domain_constraints (Method)
apply_domain_constraints(counterfactual_data::CounterfactualData, x::AbstractArray)
A subroutine that is used to apply the predetermined domain constraints.
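Interval-type domain constraints can be applied by clamping. The sketch assumes domain is a vector of (lower, upper) tuples, one per feature (row of x); apply_domain_sketch is a hypothetical name, and the formats the package accepts for domain may differ.

```julia
# Hypothetical: clamp each feature (row of `x`) to its (lower, upper) bounds.
function apply_domain_sketch(x::AbstractMatrix, domain::AbstractVector{<:Tuple})
    y = copy(x)                                  # leave the input untouched
    for (i, (lo, hi)) in enumerate(domain)
        y[i, :] .= clamp.(y[i, :], lo, hi)
    end
    return y
end
```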
CounterfactualExplanations.DataPreprocessing.fit_transformer! (Method)
fit_transformer!(
data::CounterfactualData,
input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer};
kwargs...,
)
Fit a transformer to the data in place.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(data::CounterfactualData, input_encoder::Nothing; kwargs...)
Fit a transformer to the data. This is a no-op if input_encoder is nothing.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
data::CounterfactualData,
input_encoder::Type{<:CausalInference.SCM};
kwargs...,
)
Fit a transformer to the data for an SCM object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
data::CounterfactualData,
input_encoder::Type{GenerativeModels.AbstractGenerativeModel};
kwargs...,
)
Fit a transformer to the data for a GenerativeModels.AbstractGenerativeModel object.
CounterfactualExplanations.DataPreprocessing.fit_transformer (Method)
fit_transformer(
data::CounterfactualData,
input_encoder::Type{MultivariateStats.AbstractDimensionalityReduction};
kwargs...,
)Fit a transformer to the data for a MultivariateStats.AbstractDimensionalityReduction object.
CounterfactualExplanations.DataPreprocessing.fit_transformer β Method
fit_transformer(
data::CounterfactualData,
input_encoder::Type{StatsBase.AbstractDataTransform};
kwargs...,
)Fit a transformer to the data for a StatsBase.AbstractDataTransform object.
CounterfactualExplanations.DataPreprocessing.fit_transformer β Method
fit_transformer(data::CounterfactualData, input_encoder::InputTransformer; kwargs...)Fit a transformer to the data for an InputTransformer object. This is a no-op.
CounterfactualExplanations.DataPreprocessing.mutability_constraints! β Method
mutability_constraints!(counterfactual_data::CounterfactualData, mutability)Applies provided mutability constraints in-place to existing data.
CounterfactualExplanations.DataPreprocessing.mutability_constraints β Method
mutability_constraints(counterfactual_data::CounterfactualData, mutability::Nothing)If nothing is supplied, all features are assumed to be mutable in both directions.
CounterfactualExplanations.DataPreprocessing.mutability_constraints β Method
mutabilityconstraints(counterfactualdata::CounterfactualData, mutability::Vector{Symbol})
If mutability is already a vector of integers, these are assumed to be indices of immutable features. All other features are assumed to be mutable in both directions. This offers a convenient way to quickly specify immutable features.
CounterfactualExplanations.DataPreprocessing.mutability_constraints β Method
mutabilityconstraints(counterfactualdata::CounterfactualData, mutability::Vector{Symbol})
If mutability is already a vector of symbols, it is returned as is.
CounterfactualExplanations.DataPreprocessing.mutability_constraints β Method
mutability_constraints(counterfactual_data::CounterfactualData)A convenience function that returns the mutability constraints. If none were specified, it is assumed that all features are mutable in :both directions.
CounterfactualExplanations.DataPreprocessing.mutability_constraints β Method
mutability_constraints(counterfactual_data::CounterfactualData, mutability::Pair{K,V}...) where {K<:Int,V<:Symbol}If pairs of feature indices (::Int) and constraints (::Symbol) are supplied, the given constraints are applied to the corresponding features. All other features are assumed to be mutable in both directions.
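The dispatch rules for mutability_constraints can be mirrored in a few lines of plain Julia. This sketch is hypothetical. It takes a plain feature count instead of a CounterfactualData object, and it assumes the constraint symbols are :both, :increase, :decrease and :none.

```julia
# Hypothetical helpers mirroring the documented dispatch rules.
# Assumption: constraints are encoded as :both, :increase, :decrease or :none.
normalize_mutability(n::Int, ::Nothing) = fill(:both, n)   # all features mutable

function normalize_mutability(n::Int, immutable::Vector{Int})
    m = fill(:both, n)
    m[immutable] .= :none            # listed indices become immutable
    return m
end

normalize_mutability(n::Int, m::Vector{Symbol}) = m        # returned as is

function normalize_mutability(n::Int, pairs::Pair{Int,Symbol}...)
    m = fill(:both, n)
    for (i, c) in pairs              # apply each (index => constraint) pair
        m[i] = c
    end
    return m
end

normalize_mutability(3, nothing)         # [:both, :both, :both]
normalize_mutability(3, [2])             # [:both, :none, :both]
normalize_mutability(3, 1 => :increase)  # [:increase, :both, :both]
```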
CounterfactualExplanations.DataPreprocessing.select_factual - Method
select_factual(counterfactual_data::CounterfactualData, index::Int)
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.select_factual - Method
select_factual(counterfactual_data::CounterfactualData, index::Union{Vector{Int},UnitRange{Int}})
A convenience method that can be used to access the feature matrix.
CounterfactualExplanations.DataPreprocessing.transformable_features - Method
transformable_features(counterfactual_data::CounterfactualData, input_encoder::Any)
By default, all continuous features are transformable. This function returns the indices of all continuous features.
CounterfactualExplanations.DataPreprocessing.transformable_features - Method
transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{CausalInference.SCM}
)
Returns the indices of all features that have causal parents.
CounterfactualExplanations.DataPreprocessing.transformable_features - Method
transformable_features(
counterfactual_data::CounterfactualData, input_encoder::Type{ZScoreTransform}
)
Returns the indices of all continuous features that can be transformed. Constant features are excluded, because ZScoreTransform returns NaN for them.
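Constant features must be excluded because z-scoring divides each feature by its standard deviation, and that is zero for a constant feature. The helper zscorable_features below is a hypothetical illustration that uses only the Statistics standard library, with features stored as rows.

```julia
using Statistics

# Hypothetical helper: indices of continuous features that can be z-scored,
# i.e. whose standard deviation is non-zero. Features are rows of X.
function zscorable_features(X::AbstractMatrix, continuous::Vector{Int})
    return [i for i in continuous if std(X[i, :]) > 0]
end

X = [1.0 2.0 3.0;
     5.0 5.0 5.0;     # constant feature: z-scoring gives 0 / 0 = NaN
     0.0 1.0 0.0]
zscorable_features(X, [1, 2, 3])               # [1, 3]
(X[2, :] .- mean(X[2, :])) ./ std(X[2, :])     # all NaN
```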
CounterfactualExplanations.DataPreprocessing.transformable_features - Method
transformable_features(counterfactual_data::CounterfactualData)
Dispatches the transformable_features function to the appropriate method based on the type of the dt field.
CounterfactualExplanations.Models.all_models_catalogue - Constant
all_models_catalogue
A dictionary containing both differentiable and non-differentiable machine learning models.
CounterfactualExplanations.Models.standard_models_catalogue - Constant
standard_models_catalogue
A dictionary containing all differentiable machine learning models.
CounterfactualExplanations.AbstractModel - Method
(model::AbstractModel)(X::AbstractArray)
When called on data X, logits are returned.
CounterfactualExplanations.Models.DeepEnsemble - Method
DeepEnsemble(model; likelihood::Symbol=:classification_binary)
An outer constructor for a deep ensemble model.
CounterfactualExplanations.Models.Linear - Method
Linear(model; likelihood::Symbol=:classification_binary)
An outer constructor for a linear model.
CounterfactualExplanations.Models.MLP - Method
MLP(model; likelihood::Symbol=:classification_binary)
An outer constructor for a multi-layer perceptron (MLP) model.
CounterfactualExplanations.Models.Model - Type
Model <: AbstractModel
Constructor for all models.
CounterfactualExplanations.Models.Model - Method
Model(model, type::AbstractFluxNN; likelihood::Symbol=:classification_binary)
Overloaded constructor for Flux models.
CounterfactualExplanations.Models.Model - Method
Model(model, type::AbstractModelType; likelihood::Symbol=:classification_binary)
Outer constructor for Model where the atomic model is defined and assumed to be pre-trained.
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::DeepEnsemble; kwargs...)
Constructs a deep ensemble for the given data.
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::Linear; kwargs...)
Constructs a model with one linear layer for the given data. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
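The sketch below shows how the binary and multi-class cases differ by mapping the logits of a linear layer to probabilities with the sigmoid and softmax functions. The weights are made up. sigmoid and softmax are defined locally here, not taken from Flux.

```julia
using LinearAlgebra

sigmoid(z::Real) = 1 / (1 + exp(-z))
function softmax(z::AbstractVector)
    e = exp.(z .- maximum(z))   # subtract the maximum for numerical stability
    return e ./ sum(e)
end

x = [1.0, 2.0]

# Binary case (logistic regression): one logit, passed through the sigmoid.
w, b = [0.5, -1.0], 0.0
p_binary = sigmoid(dot(w, x) + b)     # ≈ 0.18

# Multi-class case (multinomial logistic regression): one logit per class.
W = [0.5 -1.0; 1.0 0.25; -0.5 0.5]    # 3 classes × 2 features (made-up weights)
c = [0.0, 0.1, -0.1]
p_multi = softmax(W * x .+ c)         # probabilities summing to 1
```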
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::MLP; kwargs...)
Constructs a multi-layer perceptron (MLP) for the given data.
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData; kwargs...)
Wrap model M around the data in data.
CounterfactualExplanations.Models.Model - Method
Model(type::AbstractModelType; likelihood::Symbol=:classification_binary)
Outer constructor for Model where the atomic model is not yet defined.
CounterfactualExplanations.Models.fit_model - Function
fit_model(
counterfactual_data::CounterfactualData, model::Symbol=:MLP;
kwrgs...
)
Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.
CounterfactualExplanations.Models.fit_model - Method
fit_model(
counterfactual_data::CounterfactualData, type::AbstractModelType; kwrgs...
)
A wrapper function to fit a model to the counterfactual_data for a given type of model.
Arguments
- counterfactual_data::CounterfactualData: The data to be used for training the model.
- type::AbstractModelType: The type of model to be trained, e.g., MLP, DecisionTreeModel, etc.
Examples
julia> using CounterfactualExplanations
julia> using CounterfactualExplanations.Models
julia> using TaijaData
julia> data = CounterfactualData(load_linearly_separable()...);
julia> M = fit_model(data, Linear())
CounterfactualExplanations.Models.Model(Chain(Dense(2 => 2)), :classification_multi, CounterfactualExplanations.Models.Fitresult(Chain(Dense(2 => 2)), Dict{Any, Any}()), Linear())
CounterfactualExplanations.Models.logits - Method
logits(M::Model, X::AbstractArray)
Returns the logits of the model.
CounterfactualExplanations.Models.logits - Method
logits(M::Model, type::AbstractFluxNN, X::AbstractArray)
Overloads the logits function for Flux models.
CounterfactualExplanations.Models.logits - Method
logits(M::Model, type::MLJModelType, X::AbstractArray)
Overloads the logits method for MLJ models.
CounterfactualExplanations.Models.logits - Method
logits(M::Model, type::DeepEnsemble, X::AbstractArray)
Overloads the logits function for deep ensembles.
CounterfactualExplanations.Models.model_evaluation - Method
model_evaluation(M::AbstractModel, test_data::CounterfactualData)
Helper function to evaluate an AbstractModel on a (test) data set. By default, it computes the accuracy. Any other measure, e.g. from the StatisticalMeasures package, can be passed as an argument. Currently, only measures applicable to classification tasks are supported.
CounterfactualExplanations.Models.predict_label - Method
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData, X::AbstractArray)
Returns the predicted output label for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.predict_label - Method
predict_label(M::AbstractModel, counterfactual_data::CounterfactualData)
Returns the predicted output labels for all data points of data set counterfactual_data for a given model M.
CounterfactualExplanations.Models.predict_proba - Method
predict_proba(M::AbstractModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})
Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.
CounterfactualExplanations.Models.probs - Method
probs(M::Model, X::AbstractArray)
Returns the probabilities of the model.
CounterfactualExplanations.Models.probs - Method
probs(M::Model, type::AbstractFluxNN, X::AbstractArray)
Overloads the probs function for Flux models.
CounterfactualExplanations.Models.probs - Method
probs(
M::Model,
type::MLJModelType,
X::AbstractArray,
)
Overloads the probs method for MLJ models.
To Do:
Refactor this to be less convoluted and bring it in line with the current MLJ API.
CounterfactualExplanations.Models.probs - Method
probs(M::Model, type::DeepEnsemble, X::AbstractArray)
Overloads the probs function for deep ensembles.
CounterfactualExplanations.Generators.generator_catalogue - Constant
A dictionary containing the constructors of all available counterfactual generators.
CounterfactualExplanations.Generators.AbstractGradientBasedGenerator - Type
AbstractGradientBasedGenerator
An abstract type that serves as the base type for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.AbstractNonGradientBasedGenerator - Type
AbstractNonGradientBasedGenerator
An abstract type that serves as the base type for non-gradient-based counterfactual generators.
CounterfactualExplanations.Generators.FeatureTweakGenerator - Type
Feature Tweak counterfactual generator class.
CounterfactualExplanations.Generators.FeatureTweakGenerator - Method
FeatureTweakGenerator(; penalty::Union{Nothing,Function,Vector{Function}}=Objectives.distance_l2, ϵ::AbstractFloat=0.1)
Constructs a new Feature Tweak Generator object.
By default, the L2-norm is used as the penalty to measure the distance between the counterfactual and the factual. Besides the L2-norm, Tolomei et al. also recommend the L0-norm as a penalty, which simply minimizes the number of features that are changed through the tweak.
Arguments
- penalty::Penalty: The penalty function to use for the generator. Defaults to distance_l2.
- ϵ::AbstractFloat: The tolerance value for the feature tweaks, described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf). Defaults to 0.1.
Returns
generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.
CounterfactualExplanations.Generators.GradientBasedGenerator - Type
Base class for gradient-based counterfactual generators.
CounterfactualExplanations.Generators.GradientBasedGenerator - Method
GradientBasedGenerator(;
loss::Union{Nothing,Function}=nothing,
penalty::Penalty=nothing,
λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
latent_space::Bool=false,
opt::Optimisers.AbstractRule=Optimisers.Descent(),
generative_model_params::NamedTuple=(;),
)
Default outer constructor for GradientBasedGenerator.
Arguments
- loss::Union{Nothing,Function}=nothing: The loss function used by the model.
- penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
- λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
- latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
- opt::Optimisers.AbstractRule=Optimisers.Descent(): The optimizer to use for the generator.
- generative_model_params::NamedTuple: The parameters of the generative model associated with the generator.
Returns
generator::GradientBasedGenerator: A gradient-based counterfactual generator.
CounterfactualExplanations.Generators.JSMADescent - Type
An optimisation rule that can be used to implement a Jacobian-based Saliency Map Attack.
CounterfactualExplanations.Generators.JSMADescent - Method
Outer constructor for the JSMADescent rule.
CounterfactualExplanations.Generators.TCRExGenerator - Type
T-CREx counterfactual generator class.
CounterfactualExplanations.Generators.CLUEGenerator - Method
Constructor for CLUEGenerator. For details, see Antoran et al. (2021).
CounterfactualExplanations.Generators.ClaPROARGenerator - Method
Constructor for ClaPROARGenerator. For details, see Altmeyer et al. (2023).
CounterfactualExplanations.Generators.DiCEGenerator - Method
Constructor for DiCEGenerator. For details, see Mothilal et al. (2020).
CounterfactualExplanations.Generators.ECCoGenerator - Method
Constructor for ECCoGenerator. This corresponds to the generator proposed in https://arxiv.org/abs/2312.10648, without the conformal set size penalty. For details, see Altmeyer et al. (2024).
CounterfactualExplanations.Generators.GenericGenerator - Method
Constructor for GenericGenerator.
CounterfactualExplanations.Generators.GravitationalGenerator - Method
Constructor for GravitationalGenerator. For details, see Altmeyer et al. (2023).
CounterfactualExplanations.Generators.GreedyGenerator - Method
Constructor for GreedyGenerator. For details, see Schut et al. (2021).
CounterfactualExplanations.Generators.ProbeGenerator - Method
Constructor for ProbeGenerator. For details, see Pawelczyk et al. (2022).
CounterfactualExplanations.Generators.REVISEGenerator - Method
Constructor for REVISEGenerator. For details, see Joshi et al. (2019).
CounterfactualExplanations.Generators.WachterGenerator - Method
Constructor for WachterGenerator. For details, see Wachter et al. (2018).
CounterfactualExplanations.Generators.generate_perturbations - Method
generate_perturbations(
generator::AbstractGenerator, ce::AbstractCounterfactualExplanation
)
The default method to generate feature perturbations for any generator.
CounterfactualExplanations.Generators.generate_perturbations - Method
generate_perturbations(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to generate feature perturbations for gradient-based generators through simple gradient descent.
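Here is a minimal sketch of simple gradient descent on a counterfactual objective. It rests on several assumptions: a fixed logistic model with known weights, a Wachter-style objective (cross-entropy towards target class 1 plus a squared L2 penalty weighted by λ), and a 0.5 decision threshold as the stopping rule. All names are hypothetical. The package computes gradients through automatic differentiation and an Optimisers.jl rule, not the closed-form gradient used here.

```julia
using LinearAlgebra

# Hypothetical setup: logistic model p(y = 1 | x) = σ(w'x + b), target class 1.
# Objective: L(x) = -log σ(w'x + b) + λ * ||x - x0||²
σ(z) = 1 / (1 + exp(-z))

function objective_gradient(x, x0, w, b, λ)
    # Gradient of -log σ(w'x + b) is -(1 - σ(w'x + b)) * w; the penalty adds 2λ(x - x0).
    return -(1 - σ(dot(w, x) + b)) .* w .+ 2λ .* (x .- x0)
end

function gradient_descent(x0, w, b; λ=0.1, η=0.5, steps=100, threshold=0.5)
    x = copy(x0)
    for _ in 1:steps
        σ(dot(w, x) + b) > threshold && break          # target class reached
        Δ = -η .* objective_gradient(x, x0, w, b, λ)   # feature perturbation
        x .+= Δ
    end
    return x
end

x_cf = gradient_descent([0.0, 0.0], [1.0, 1.0], -3.0)
```

With these made-up weights the factual starts at σ(-3) ≈ 0.05. The search stops just after the counterfactual crosses the decision boundary w'x + b = 0. The penalty term keeps it from moving further than needed.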
CounterfactualExplanations.Generators.@objective - Macro
@objective(generator, ex)
A macro that can be used to define the counterfactual search objective.
CounterfactualExplanations.Generators.@search_feature_space - Macro
@search_feature_space(generator)
A simple macro that can be used to specify feature space search.
CounterfactualExplanations.Generators.@search_latent_space - Macro
@search_latent_space(generator)
A simple macro that can be used to specify latent space search.
CounterfactualExplanations.Generators.@with_optimiser - Macro
@with_optimiser(generator, optimiser)
A simple macro that can be used to specify the optimiser to be used.
CounterfactualExplanations.Objectives.choose_ad_backend - Method
choose_ad_backend(backends::Vararg{<:Any})
Select a single automatic differentiation backend from multiple provided backends.
Arguments
- backends::Vararg{<:Any}: Variable number of AD backend instances (e.g., AutoZygote(), AutoEnzyme()).
Returns
- AutoZygote() if all provided backends are AutoZygote instances.
- The single non-AutoZygote backend if exactly one is provided among AutoZygote instances.
Throws
- AssertionError if more than one non-AutoZygote backend is provided.
Examples
choose_ad_backend(AutoZygote(), AutoZygote()) # Returns AutoZygote()
choose_ad_backend(AutoZygote(), AutoEnzyme()) # Returns AutoEnzyme()
choose_ad_backend(AutoZygote(), AutoEnzyme(), AutoForwardDiff()) # Throws AssertionError
CounterfactualExplanations.Objectives.choose_ad_backend - Method
choose_ad_backend(ce::CounterfactualExplanation)
Select a compatible automatic differentiation backend for a counterfactual explanation.
Determines the AD backend by querying the backends of the model and generator within the counterfactual explanation, then reconciling them into a single compatible backend.
Arguments
- ce::CounterfactualExplanation: A counterfactual explanation object containing a model (M) and a generator.
Returns
- A single AD backend compatible with both the model and generator.
Throws
- AssertionError if the model and generator use incompatible non-AutoZygote backends.
See Also
CounterfactualExplanations.Objectives.choose_ad_backend - Method
choose_ad_backend(gen::AbstractGenerator)
Choose an appropriate automatic differentiation backend for a given generator based on its penalty functions. Handles both simple cases (no penalty function) and complex cases where multiple penalties might require different AD backends, ensuring that these backends are mutually exclusive.
CounterfactualExplanations.Objectives.choose_ad_backend - Method
Chooses the appropriate AD backend for the given model type.
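The reconciliation rule documented for choose_ad_backend(backends...) can be stated compactly. The sketch uses empty stand-in structs instead of the ADTypes.jl backends (AutoZygote, AutoEnzyme, AutoForwardDiff), so it runs without extra dependencies. It follows the documented behaviour literally and is not the package's implementation.

```julia
# Hypothetical stand-ins for ADTypes.jl backends.
struct FakeZygote end
struct FakeEnzyme end
struct FakeForwardDiff end

function reconcile_backends(backends...)
    others = [b for b in backends if !(b isa FakeZygote)]
    isempty(others) && return FakeZygote()                 # all Zygote: keep Zygote
    @assert length(others) == 1 "More than one non-Zygote backend provided."
    return only(others)                                    # the single non-Zygote backend
end

reconcile_backends(FakeZygote(), FakeZygote())   # FakeZygote()
reconcile_backends(FakeZygote(), FakeEnzyme())   # FakeEnzyme()
```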
CounterfactualExplanations.Objectives.ddp_diversity - Method
ddp_diversity(
ce::AbstractCounterfactualExplanation;
perturbation_size=1e-5
)
Evaluates how diverse the counterfactuals are using a Determinantal Point Process (DPP).
CounterfactualExplanations.Objectives.ddp_diversity - Method
ddp_diversity(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
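This sketch follows the DPP diversity term of Mothilal et al. (2020) for DiCE. It builds a kernel matrix K with K[i, j] = 1 / (1 + d(c_i, c_j)) over the counterfactuals c_i and uses its determinant as the diversity score, so more spread-out counterfactuals give a larger determinant. It assumes Euclidean distances, counterfactuals stored as columns, and a small deterministic diagonal jitter standing in for perturbation_size. The package may apply that perturbation and aggregate the result differently.

```julia
using LinearAlgebra

# Hypothetical sketch of DPP-based diversity (not the package's code).
function dpp_diversity(C::AbstractMatrix; perturbation_size=1e-5)
    n = size(C, 2)                                        # number of counterfactuals
    K = [1 / (1 + norm(C[:, i] - C[:, j])) for i in 1:n, j in 1:n]
    return det(K + perturbation_size * I)                 # jitter keeps K well-conditioned
end

close_cfs  = [0.0 0.1; 0.0 0.1]    # two almost identical counterfactuals
spread_cfs = [0.0 3.0; 0.0 4.0]    # two counterfactuals 5 units apart
dpp_diversity(close_cfs) < dpp_diversity(spread_cfs)      # true
```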
CounterfactualExplanations.Objectives.distance - Method
distance(
cf::AbstractArray,
from::AbstractArray;
agg=mean,
p::Real=1,
weights::Union{Nothing,AbstractArray}=nothing,
cosine::Bool=false,
d::Union{Nothing,Vector{Int}}=nothing,
)
Computes the distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance - Method
distance(ce::AbstractCounterfactualExplanation; kwrgs...)
Overloads the method to be applied directly to ce.
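A small re-implementation shows what p and agg do. pnorm_distance is hypothetical, not a package function, and it ignores the cosine and d keywords. With p = 0, 1, 2 and Inf it gives the L0, L1, L2 and L-inf distances, which is presumably what distance_l0, distance_l1, distance_l2 and distance_linf compute. That mapping is an assumption.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch: columns of `cf` are counterfactuals, `from` is the factual.
# Optional `weights` rescale each feature's difference before taking the norm.
function pnorm_distance(cf::AbstractMatrix, from::AbstractVector;
                        p::Real=1, agg=mean, weights=nothing)
    Δ = cf .- from                                   # broadcast over columns
    weights === nothing || (Δ = Δ .* weights)
    return agg([norm(Δ[:, k], p) for k in axes(Δ, 2)])
end

x0 = [0.0, 0.0]
cf = [3.0 1.0;
      4.0 0.0]                     # two counterfactuals
pnorm_distance(cf, x0; p=1)        # mean(7, 1) = 4.0
pnorm_distance(cf, x0; p=2)        # mean(5, 1) = 3.0
pnorm_distance(cf, x0; p=0)        # mean(2, 1) = 1.5 (number of changed features)
```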
CounterfactualExplanations.Objectives.distance_cosine - Method
distance_cosine(ce::AbstractCounterfactualExplanation)
Computes the distance of the counterfactual to the original factual using cosine similarity. See also: cos_dist.
CounterfactualExplanations.Objectives.distance_cosine - Method
distance_cosine(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_from_target - Method
distance_from_target(
ce::AbstractCounterfactualExplanation;
K::Int=50
)
Computes the distance of the counterfactual from samples in the target manifold. If choose_randomly is true, the function randomly samples K neighbours from the target manifold. Otherwise, it computes the pairwise distances and selects the K closest neighbours.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation.
- K::Int=50: The number of neighbours to sample.
- choose_randomly::Bool=true: Whether to sample neighbours randomly.
- kwrgs...: Additional keyword arguments for the distance function.
Returns
Δ::AbstractFloat: The distance from the counterfactual to the target manifold.
CounterfactualExplanations.Objectives.distance_from_target - Method
distance_from_target(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
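The deterministic branch of distance_from_target (choose_randomly = false) can be sketched as the average distance from the counterfactual to its K nearest target-class samples. The helper below is hypothetical. It assumes Euclidean distance and target samples stored as columns, and it caps K at the number of available samples.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch: mean Euclidean distance to the K nearest target samples.
function mean_dist_to_k_nearest(x_cf::AbstractVector, X_target::AbstractMatrix; K::Int=50)
    d = [norm(x_cf - X_target[:, j]) for j in axes(X_target, 2)]
    K = min(K, length(d))                 # cannot use more neighbours than samples
    return mean(partialsort(d, 1:K))      # K smallest distances
end

X_target = [1.0 2.0 10.0;
            0.0 0.0  0.0]                 # three target-class samples
mean_dist_to_k_nearest([0.0, 0.0], X_target; K=2)   # mean(1, 2) = 1.5
```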
CounterfactualExplanations.Objectives.distance_from_target_cosine - Method
distance_from_target_cosine(ce::AbstractCounterfactualExplanation; kwrgs...)
Computes the distance from a counterfactual to the target manifold using cosine similarity. See also: cos_dist.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
- kwrgs...: Additional keyword arguments for the distance function.
CounterfactualExplanations.Objectives.distance_from_target_cosine - Method
distance_from_target_cosine(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_l0 - Method
distance_l0(ce::AbstractCounterfactualExplanation)
Computes the L0 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l0 - Method
distance_l0(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_l1 - Method
distance_l1(ce::AbstractCounterfactualExplanation)
Computes the L1 distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l1 - Method
distance_l1(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_l2 - Method
distance_l2(ce::AbstractCounterfactualExplanation)
Computes the L2 (Euclidean) distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_l2 - Method
distance_l2(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_linf - Method
distance_linf(ce::AbstractCounterfactualExplanation)
Computes the L-inf distance of the counterfactual to the original factual.
CounterfactualExplanations.Objectives.distance_linf - Method
distance_linf(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.distance_mad - Method
distance_mad(ce::AbstractCounterfactualExplanation; agg=mean)
This is the distance measure proposed by Wachter et al. (2017).
CounterfactualExplanations.Objectives.distance_mad - Method
distance_mad(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
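Wachter et al. (2017) weight the L1 distance by the inverse median absolute deviation (MAD) of each feature: d(x, x′) = Σ_k |x_k − x′_k| / MAD_k. The sketch below implements that formula for a single counterfactual. The helper names are hypothetical. It assumes features are rows of the training matrix X and that no feature has a MAD of zero.

```julia
using Statistics

# MAD of each feature (row) over the training data.
mad_per_feature(X::AbstractMatrix) =
    [median(abs.(X[k, :] .- median(X[k, :]))) for k in axes(X, 1)]

# Hypothetical sketch of the MAD-weighted L1 distance.
function mad_distance(x_cf::AbstractVector, x::AbstractVector, X::AbstractMatrix)
    return sum(abs.(x_cf .- x) ./ mad_per_feature(X))
end

X = [1.0  2.0  3.0  4.0  5.0;
     10.0 20.0 30.0 40.0 50.0]    # MADs are 1.0 and 10.0
mad_distance([3.0, 30.0], [1.0, 10.0], X)   # 2 / 1 + 20 / 10 = 4.0
```

Dividing by the MAD puts features on comparable scales, so a change of 20 in the second feature counts the same as a change of 2 in the first.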
CounterfactualExplanations.Objectives.hinge_loss - Method
hinge_loss(ce::AbstractCounterfactualExplanation)
Calculates the hinge loss of a counterfactual explanation with InvalidationRateConvergence.
Arguments
ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.
Returns
The hinge loss of the counterfactual explanation.
CounterfactualExplanations.Objectives.hinge_loss - Method
hinge_loss(ce::AbstractCounterfactualExplanation)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.predictive_entropy - Method
predictive_entropy(ce::AbstractCounterfactualExplanation; agg=Statistics.mean)
Computes the predictive entropy of the counterfactuals, as explained in https://arxiv.org/abs/1406.2541.
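Predictive entropy is H(p) = −Σ_c p_c log p_c, computed for each counterfactual's predicted class distribution and then aggregated. The helper below is a hypothetical sketch. It assumes each column of P already holds class probabilities that sum to 1, and it treats 0 log 0 as 0.

```julia
using Statistics

# Hypothetical sketch: entropy of each column of P, aggregated with `agg`.
function entropy_per_column(P::AbstractMatrix; agg=mean)
    H = [-sum(p * log(p) for p in P[:, k] if p > 0) for k in axes(P, 2)]
    return agg(H)
end

P = [0.5 1.0;
     0.5 0.0]          # uncertain vs. certain prediction
entropy_per_column(P)  # mean(log(2), 0) ≈ 0.347
```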
Flux.Losses.logitbinarycrossentropy - Method
Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Flux.Losses.logitcrossentropy - Method
Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)
Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.
Internal functions
CounterfactualExplanations.CRE - Type
CRE <: AbstractCounterfactualExplanation
A Counterfactual Rule Explanation (CRE) is a global explanation for a given target, model M, data and generator.
CounterfactualExplanations.CRE - Method
(cre::CRE)(x::AbstractArray)
Generates a local counterfactual point explanation for x using the generator.
CounterfactualExplanations.DecisionTreeModel - Type
DecisionTreeModel
Concrete type for tree-based models from DecisionTree.jl. Since DecisionTree.jl has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.FluxModelParams - Type
FluxModelParams
Default MLP training parameters.
CounterfactualExplanations.JEM - Type
JEM
Concrete type for joint-energy models from JointEnergyModels. Since JointEnergyModels has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.LaplaceReduxModel - Type
LaplaceReduxModel
Concrete type for neural networks with Laplace Approximation from the LaplaceRedux package. It currently subtypes the AbstractFluxNN model type, although this may change to MLJ in the future.
CounterfactualExplanations.NeuroTreeModel - Type
NeuroTreeModel
Concrete type for differentiable tree-based models from NeuroTreeModels. Since NeuroTreeModels has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.Objectives.ADRequirements - Method
The LaplaceReduxModel model type requires ForwardDiff.
CounterfactualExplanations.RandomForestModel - Type
RandomForestModel
Concrete type for random forest models from DecisionTree.jl. Since the DecisionTree package has an MLJ interface, we subtype the MLJModelType model type.
CounterfactualExplanations.Rule - Type
Rule
A Rule is just a list of bounds for the different features. See also CRE.
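A rule described as a list of per-feature bounds can be checked against a point in one line. The BoundsRule type and the satisfies function below are hypothetical illustrations, not the package's Rule type. They assume one (lower, upper) tuple per feature, with -Inf or Inf for unbounded sides.

```julia
# Hypothetical sketch of a rule as per-feature (lower, upper) bounds.
struct BoundsRule
    bounds::Vector{Tuple{Float64,Float64}}
end

# A point satisfies the rule if every feature lies within its bounds.
satisfies(r::BoundsRule, x::AbstractVector) =
    all(lb <= xi <= ub for (xi, (lb, ub)) in zip(x, r.bounds))

r = BoundsRule([(0.0, 1.0), (-Inf, 2.0)])
satisfies(r, [0.5, -3.0])   # true
satisfies(r, [1.5, 0.0])    # false: first feature above its upper bound
```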
Base.Broadcast.broadcastable - Method
Treat AbstractGenerator as a scalar when broadcasting.
Base.Broadcast.broadcastable - Method
Treat AbstractModel as a scalar when broadcasting.
Base.Broadcast.broadcastable - Method
Treat AbstractPenalty as a scalar when broadcasting.
CounterfactualExplanations.adjust_shape! - Method
adjust_shape!(ce::CounterfactualExplanation)
A convenience method that adjusts the dimensions of the counterfactual state and related fields.
CounterfactualExplanations.adjust_shape - Method
adjust_shape(
ce::CounterfactualExplanation,
x::AbstractArray
)
A convenience method that adjusts the dimensions of x.
CounterfactualExplanations.already_in_target_class - Method
already_in_target_class(ce::CounterfactualExplanation)
Check if the factual is already in the target class.
CounterfactualExplanations.apply_domain_constraints! - Method
apply_domain_constraints!(ce::CounterfactualExplanation)
Wrapper function that applies underlying domain constraints.
CounterfactualExplanations.apply_mutability - Method
apply_mutability(
ce::CounterfactualExplanation,
grad_ce_state::AbstractArray,
)
A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
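The sketch below shows one way per-feature mutability constraints could restrict a proposed perturbation Δ, which is the change about to be added to the features. It is hypothetical. It assumes the constraint symbols :both, :increase, :decrease and :none, one per feature (row). If the constraints were applied to the raw gradient instead of the perturbation, the roles of :increase and :decrease would flip, because a descent step moves against the gradient.

```julia
# Hypothetical sketch: mask a perturbation Δ (features × counterfactuals).
function mask_perturbation(Δ::AbstractMatrix, mutability::Vector{Symbol})
    out = copy(Δ)
    for (i, m) in enumerate(mutability)
        row = view(out, i, :)
        if m == :none
            row .= 0                  # immutable feature: no change at all
        elseif m == :increase
            row .= max.(row, 0)       # drop negative changes
        elseif m == :decrease
            row .= min.(row, 0)       # drop positive changes
        end                           # :both leaves the row unchanged
    end
    return out
end

Δ = reshape([0.3, -0.2, 0.4, -0.1], 4, 1)
mask_perturbation(Δ, [:both, :increase, :decrease, :none])   # [0.3; 0.0; 0.0; 0.0]
```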
CounterfactualExplanations.counterfactual - Method
counterfactual(ce::AbstractCounterfactualExplanation)
A convenience method that returns the counterfactual.
CounterfactualExplanations.counterfactual_label - Method
counterfactual_label(ce::CounterfactualExplanation)
A convenience method that returns the predicted label of the counterfactual.
CounterfactualExplanations.counterfactual_label_path - Method
counterfactual_label_path(ce::CounterfactualExplanation)
Returns the counterfactual labels for each step of the search.
CounterfactualExplanations.counterfactual_probability - Function
counterfactual_probability(ce::CounterfactualExplanation)
A convenience method that computes the class probabilities of the counterfactual.
CounterfactualExplanations.counterfactual_probability_path - Method
counterfactual_probability_path(ce::CounterfactualExplanation)
Returns the counterfactual probabilities for each step of the search.
CounterfactualExplanations.decode_array - Method
decode_array(
data::CounterfactualData,
dt::CausalInference.SCM,
x::AbstractArray,
)
Helper function to decode an array x using a data transform dt::CausalInference.SCM.
CounterfactualExplanations.decode_array - Method
decode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to decode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.decode_array - Method
decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.decode_array - Method
decode_array(dt::Nothing, x::AbstractArray)
Helper function to decode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.decode_array - Method
decode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to decode an array x using a data transform dt::StatsBase.AbstractDataTransform.
CounterfactualExplanations.decode_state - Function
function decode_state(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Applies all the applicable decoding functions:
- If applicable, map the state variable back from the latent space to the feature space.
- If and where applicable, inverse-transform features.
- Reconstruct all categorical encodings.
Finally, the decoded counterfactual is returned.
CounterfactualExplanations.decode_state! - Function
decode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of decode_state.
CounterfactualExplanations.encode_array - Method
encode_array(data::CounterfactualData, dt::CausalInference.SCM, x::AbstractArray)
Helper function to encode an array x using a data transform dt::CausalInference.SCM. This is a no-op.
CounterfactualExplanations.encode_array - Method
encode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)
Helper function to encode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.
CounterfactualExplanations.encode_array - Method
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)
Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.
CounterfactualExplanations.encode_array - Method
encode_array(dt::Nothing, x::AbstractArray)
Helper function to encode an array x using a data transform dt::Nothing. This is a no-op.
CounterfactualExplanations.encode_array - Method
encode_array(dt::StatsBase.AbstractDataTransform, x::AbstractArray)
Helper function to encode an array x using a data transform dt::StatsBase.AbstractDataTransform.
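Encoding and decoding with a StatsBase-style transform form a round trip. The sketch below shows this for a z-score transform written from scratch with the Statistics standard library. The ZScore struct and the fit_zscore, encode and decode functions are hypothetical stand-ins for StatsBase's ZScoreTransform, with features stored as rows.

```julia
using Statistics

# Hypothetical z-score transform: per-feature (row) mean and standard deviation.
struct ZScore
    μ::Vector{Float64}
    σ::Vector{Float64}
end
fit_zscore(X::AbstractMatrix) = ZScore(vec(mean(X; dims=2)), vec(std(X; dims=2)))
encode(t::ZScore, x::AbstractVecOrMat) = (x .- t.μ) ./ t.σ   # standardize
decode(t::ZScore, z::AbstractVecOrMat) = z .* t.σ .+ t.μ     # invert

X = [1.0 2.0 3.0;
     10.0 20.0 30.0]
t = fit_zscore(X)
z = encode(t, X)       # each row now has mean 0 and standard deviation 1
decode(t, z) ≈ X       # true: decoding recovers the original features
```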
CounterfactualExplanations.encode_state - Function
function encode_state(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
Applies all required encodings to x:
- If applicable, it maps x to the latent space learned by the generative model.
- If and where applicable, it rescales features.
Finally, it returns the encoded state variable.
CounterfactualExplanations.encode_state! - Function
encode_state!(ce::CounterfactualExplanation, x::Union{AbstractArray,Nothing}=nothing)
In-place version of encode_state.
CounterfactualExplanations.factual - Method
factual(ce::AbstractCounterfactualExplanation)
A convenience method to retrieve the factual x.
CounterfactualExplanations.factual_label - Method
factual_label(ce::CounterfactualExplanation)
A convenience method to get the predicted label associated with the factual.
CounterfactualExplanations.factual_probability - Method
factual_probability(ce::CounterfactualExplanation)
A convenience method to compute the class probabilities of the factual.
CounterfactualExplanations.find_potential_neighbours β Function
find_potential_neighbours(ce::CounterfactualExplanation, n::Int=1000)Overloads the function for CounterfactualExplanation to use the counterfactual data's labels if no data is provided.
CounterfactualExplanations.find_potential_neighbours β Function
find_potential_neighbours(
ce::AbstractCounterfactualExplanation, data::CounterfactualData, n::Int=1000
)
Finds potential neighbors for the selected factual data point.
CounterfactualExplanations.get_meta - Method
get_meta(ce::CounterfactualExplanation)
Returns metadata for a counterfactual explanation.
CounterfactualExplanations.guess_likelihood - Method
guess_likelihood(y::RawOutputArrayType)
Guesses the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.
CounterfactualExplanations.guess_loss - Method
guess_loss(ce::CounterfactualExplanation)
Guesses the loss function to be used for the counterfactual search in case the likelihood field is specified for the AbstractModel instance and no loss function was explicitly declared for the AbstractGenerator instance.
CounterfactualExplanations.initialize! - Method
initialize!(ce::CounterfactualExplanation)
Initializes the counterfactual explanation. This method is called by the constructor. It does the following:
- Creates a dictionary to store information about the search.
- Initializes the counterfactual state.
- Initializes the search path.
- Initializes the loss.
CounterfactualExplanations.initialize_state! - Method
initialize_state!(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s) in-place.
CounterfactualExplanations.initialize_state - Method
initialize_state(ce::CounterfactualExplanation)
Initializes the starting point for the factual(s):
- If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
- If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack et al. (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.
CounterfactualExplanations.outdim - Method
outdim(ce::CounterfactualExplanation)
A convenience method that returns the output dimension of the predictive model.
CounterfactualExplanations.polynomial_decay - Method
polynomial_decay(a::Real, b::Real, decay::Real, t::Int)
Computes the polynomial decay function as in Welling and Teh (2011): https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf.
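The decay schedule fits in a few lines. This minimal sketch assumes the step-size form ε_t = a(b + t)^(-γ) from Welling and Teh (2011), with decay playing the role of γ. The name polynomial_decay_sketch is hypothetical, and the package's implementation may differ in detail.

```julia
# Hypothetical sketch of a polynomial step-size decay in the style of
# Welling and Teh (2011): ε_t = a * (b + t)^(-decay).
function polynomial_decay_sketch(a::Real, b::Real, decay::Real, t::Int)
    return a * (b + t)^(-decay)
end

# Step sizes shrink monotonically as the iteration counter t grows.
steps = [polynomial_decay_sketch(1.0, 1.0, 0.55, t) for t in 1:5]
```

Larger values of b flatten the early part of the schedule, and decay controls how quickly the step size vanishes.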
CounterfactualExplanations.reset! - Method
reset!(flux_training_params::FluxModelParams)
Restores the default parameter values.
CounterfactualExplanations.steps_exhausted - Method
steps_exhausted(ce::AbstractCounterfactualExplanation)
A convenience method that checks if the maximum number of iterations has been exhausted.
CounterfactualExplanations.target_probs_path - Method
target_probs_path(ce::CounterfactualExplanation)
Returns the target probabilities for each step of the search.
CounterfactualExplanations.update! - Method
update!(ce::CounterfactualExplanation)
An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator. Based on the current state, the generator generates perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.
CounterfactualExplanations.Convergence.max_iter - Method
max_iter(conv::AbstractConvergence)
Returns the maximum number of iterations specified.
CounterfactualExplanations.Evaluation.AbstractCETransformer - Type
An abstract type for CE transformers.
CounterfactualExplanations.Evaluation.AbstractOutputIdentifier - Type
Abstract type for output identifiers.
CounterfactualExplanations.Evaluation.AbstractSerializer - Type
Abstract type for serializers.
CounterfactualExplanations.Evaluation.EnergySampler - Type
Base type that stores information relevant to energy-based posterior sampling from AbstractModel.
CounterfactualExplanations.Evaluation.EnergySampler - Method
EnergySampler(
model::AbstractModel,
𝒟x::Distribution,
𝒟y::Distribution,
input_size::Dims,
yidx::Int;
opt::Union{Nothing,AbstractSamplingRule}=nothing,
nsamples::Int=100,
niter_final::Int=1000,
ntransitions::Int=0,
opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing,
niter::Int=20,
batch_size::Int=50,
prob_buffer::AbstractFloat=0.95,
kwargs...,
)
Constructor for EnergySampler, which is used to sample from the posterior distribution of the model conditioned on y.
Arguments
- model::AbstractModel: The model to be used for sampling.
- data::CounterfactualData: The data to be used for sampling.
- y::Any: The conditioning value.
- opt::Union{Nothing,AbstractSamplingRule}=nothing: The sampling rule to be used. By default, SGLD is used with a = (2 / std(Uniform())) * std(𝒟x), b = 1 and γ = 0.9.
- nsamples::Int=100: The number of samples to include in the final empirical posterior distribution.
- niter_final::Int=1000: The number of iterations for generating samples from the posterior distribution. Typically, this number will be larger than the number of iterations during PMC training.
- ntransitions::Int=0: The number of transitions for (optionally) warming up the sampler. By default, this is set to 0 and the sampler is not warmed up. For values larger than 0, the sampler is trained through PMC for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.
- opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing: The sampling rule to be used for warm-up. By default, ImproperSGLD is used with α = (2 / std(Uniform())) * std(𝒟x) and γ = 0.005α.
- niter::Int=20: The number of iterations for training the sampler through PMC.
- batch_size::Int=50: The batch size for training the sampler.
- prob_buffer::AbstractFloat=0.95: The probability of drawing samples from the replay buffer. Smaller values will result in more samples being drawn from the prior and typically lead to better mixing and diversity in the samples.
- kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.
Returns
EnergySampler: An instance of EnergySampler.
CounterfactualExplanations.Evaluation.EnergySampler - Method
EnergySampler(ce::CounterfactualExplanation; kwrgs...)
Overloads the EnergySampler constructor to accept a CounterfactualExplanation object.
Base.rand - Function
Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
Overloads the rand method to randomly draw n samples from EnergySampler. If from_posterior is true, the samples are drawn from the posterior distribution. Otherwise, the samples are generated from the model conditioned on the target value using a single chain (see generate_posterior_samples).
Arguments
- sampler::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=100: The number of samples to draw.
- from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
- niter::Int=500: The number of iterations for generating samples through Monte Carlo sampling (single chain).
Returns
AbstractArray: The samples.
CounterfactualExplanations.Evaluation.SerializationState - Method
Serializer that does not allow serialization.
CounterfactualExplanations.Evaluation.SerializationState - Method
Serializer that allows serialization.
CounterfactualExplanations.Evaluation.TransformationFunction - Method
Transformation function for the explicit transformer.
CounterfactualExplanations.Evaluation.TransformationFunction - Method
Transformation function for the default transformer.
CounterfactualExplanations.Evaluation._ce_transform - Function
Global CE transformer.
CounterfactualExplanations.Evaluation.compute_measure - Method
compute_measure(ce::CounterfactualExplanation, measure::AbstractDivergenceMetric, agg::Function)
For abstract divergence metrics, returns a vector of NaN values.
CounterfactualExplanations.Evaluation.compute_measure - Method
compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)
Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.
CounterfactualExplanations.Evaluation.define_prior - Method
define_prior(
data::CounterfactualData;
𝒟x::Union{Nothing,Distribution}=nothing,
𝒟y::Union{Nothing,Distribution}=nothing,
n_std::Int=3,
)
Defines the prior for the data. The space is defined as a uniform distribution with bounds defined by the mean and standard deviation of the data. The bounds are extended by n_std standard deviations.
Arguments
- data::CounterfactualData: The data to be used for defining the prior sampling space.
- n_std::Int=3: The number of standard deviations to extend the bounds.
Returns
Uniform: The uniform distribution defining the prior sampling space.
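The sketch below shows one way such bounds could be derived. It pools all entries of a feature matrix and returns the mean ± n_std standard deviations. Several choices here are assumptions rather than documented behaviour: pooling across features, the features-by-samples layout, and the name prior_bounds_sketch. With Distributions.jl, the resulting bounds could then be wrapped as Uniform(lo, hi).

```julia
using Statistics

# Hypothetical sketch: bounds of a uniform prior sampling space computed as
# mean ± n_std * std over all entries of X (features × samples).
function prior_bounds_sketch(X::AbstractMatrix; n_std::Int=3)
    μ = mean(X)   # mean over all entries
    σ = std(X)    # sample standard deviation over all entries
    return (μ - n_std * σ, μ + n_std * σ)
end

X = [0.0 2.0; 0.0 2.0]
lo, hi = prior_bounds_sketch(X; n_std=1)
```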
CounterfactualExplanations.Evaluation.distance_from_posterior - Method
distance_from_posterior(ce::AbstractCounterfactualExplanation)
Computes the distance from the counterfactual to generated conditional samples. The distance is computed as the mean distance from the counterfactual to the samples drawn from the posterior distribution of the model. By default, the cosine distance is used.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
- nsamples::Int=1000: The number of samples to draw.
- from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
- agg: The aggregation function to use for computing the distance.
- choose_lowest_energy::Bool=true: Whether to choose the samples with the lowest energy.
- choose_random::Bool=false: Whether to choose random samples.
- nmin::Int=25: The minimum number of samples to choose.
- p::Int=1: The norm to use for computing the distance.
- cosine::Bool=true: Whether to use the cosine distance.
- kwargs...: Additional keyword arguments to be passed on to the EnergySampler.
Returns
AbstractFloat: The distance from the counterfactual to the samples.
CounterfactualExplanations.Evaluation.generate_posterior_samples - Function
generate_posterior_samples(
e::EnergySampler, n::Int=1000; niter::Int=1000, kwargs...
)
Generates n samples from the posterior distribution of the model conditioned on the target value y. The samples are generated through (Persistent) Monte Carlo sampling using the EnergySampler object. If the replay buffer is not empty, the initial samples are drawn from the buffer.
Note that by default the batch size of the sampler is set to round(Int, n / 100) for sampling. This is to ensure that the samples are drawn independently from the posterior distribution. It also helps to avoid vanishing gradients.
The chain is run persistently until n samples are generated. The number of transitions is set to ceil(Int, n / batch_size). Once the chain is run, the last n samples from the replay buffer are returned.
Arguments
- e::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=1000: The number of samples to generate.
- batch_size::Int=round(Int, n / 100): The batch size for sampling.
- niter::Int=1000: The number of iterations for generating samples from the posterior distribution.
- kwargs...: Additional keyword arguments to be passed on to the sampler.
Returns
AbstractArray: The generated samples.
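The batch-size and transition arithmetic from the description can be checked directly. The helper sampling_schedule_sketch is hypothetical and not part of the package. Note that round(Int, n / 100) evaluates to 0 for n ≤ 50, so this sketch clamps the batch size to at least 1. That clamp is an assumption of the sketch, not documented behaviour.

```julia
# Hypothetical sketch of the schedule described above:
#   batch_size   = round(Int, n / 100)   (clamped to ≥ 1 in this sketch)
#   ntransitions = ceil(Int, n / batch_size)
function sampling_schedule_sketch(n::Int)
    batch_size = max(1, round(Int, n / 100))
    ntransitions = ceil(Int, n / batch_size)
    return (; batch_size, ntransitions)
end

sampling_schedule_sketch(1000)  # (batch_size = 10, ntransitions = 100)
```

Since ntransitions * batch_size ≥ n, the replay buffer always holds at least n samples once the chain has finished.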
CounterfactualExplanations.Evaluation.get_benchmark_files - Method
get_benchmark_files(storage_path::String)
Returns a list of all benchmark files stored in storage_path.
CounterfactualExplanations.Evaluation.get_lowest_energy_sample - Method
get_lowest_energy_sample(sampler::EnergySampler; n::Int=5)
Chooses the samples with the lowest energy (i.e. highest probability) from EnergySampler.
Arguments
- sampler::EnergySampler: The EnergySampler object to be used for sampling.
- n::Int=5: The number of samples to choose.
Returns
AbstractArray: The samples with the lowest energy.
CounterfactualExplanations.Evaluation.get_sampler! - Method
get_sampler!(ce::AbstractCounterfactualExplanation; kwargs...)
Gets the EnergySampler object from the counterfactual explanation. If the sampler is not found, it is constructed and stored in the counterfactual explanation object.
CounterfactualExplanations.Evaluation.includes_divergence_metric - Method
includes_divergence_metric(measure::Union{Function,Vector{<:Function}})
Checks if the provided measure includes a divergence metric.
CounterfactualExplanations.Evaluation.kernelsum - Method
kernelsum(k::KernelFunctions.Kernel, x::AbstractMatrix, y::AbstractMatrix)
Computes the sum of kernel matrices between two matrices x and y. This function sums all kernel evaluations comparing columns in x to columns in y.
CounterfactualExplanations.Evaluation.kernelsum - Method
kernelsum(k::KernelFunctions.Kernel, x::AbstractMatrix)
Computes the sum of kernel matrices between x and itself. This function sums all kernel evaluations comparing columns in x to each other and subtracts the trace of the resulting matrix to account for self-evaluations. The result is then divided by (m^2 - m), where m is the number of columns in x. This effectively gives the mean of all pairwise kernel evaluations excluding self-evaluations.
CounterfactualExplanations.Evaluation.kernelsum - Method
kernelsum(k::KernelFunctions.Kernel, x::AbstractVector)
Computes the sum of kernel matrices between x and itself, where x is a vector. This function returns 0 for vectors since there are no pairs to compute the kernel matrix.
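The three kernelsum methods can be illustrated without KernelFunctions.jl by passing a plain Julia function as the kernel. In this sketch, rbf and kernelsum_sketch are hypothetical names, and the Gaussian kernel stands in for any KernelFunctions.Kernel. For a single column (m = 1), the self-comparison method computes 0/0 and yields NaN; the sketch does not guard against this.

```julia
using LinearAlgebra

# Hypothetical stand-in for a KernelFunctions.jl kernel:
# Gaussian (RBF) kernel k(a, b) = exp(-‖a - b‖² / 2).
rbf(a, b) = exp(-sum(abs2, a .- b) / 2)

# Sum of k(xᵢ, yⱼ) over all pairs of columns of x and y.
function kernelsum_sketch(k, x::AbstractMatrix, y::AbstractMatrix)
    return sum(k(xi, yj) for xi in eachcol(x), yj in eachcol(y))
end

# Mean of k(xᵢ, xⱼ) over pairs i ≠ j: drop the diagonal (trace),
# then divide by m² - m, the number of off-diagonal entries.
function kernelsum_sketch(k, x::AbstractMatrix)
    m = size(x, 2)
    K = [k(xi, xj) for xi in eachcol(x), xj in eachcol(x)]
    return (sum(K) - tr(K)) / (m^2 - m)
end

# A single vector has no distinct pairs, so the result is 0.
kernelsum_sketch(k, x::AbstractVector) = 0.0
```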
CounterfactualExplanations.Evaluation.needs_ce - Method
needs_ce(store_ce::Bool, measure::Union{Function,Vector{<:Function}})
A helper function to determine if counterfactual explanations should be stored based on the given store_ce flag and the presence of a divergence metric in the measure.
CounterfactualExplanations.Evaluation.samplecolumns - Method
samplecolumns([rng::AbstractRNG], x::AbstractMatrix, n::Int)
Samples n columns from a matrix. Returns x if the matrix has fewer than n columns.
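Here is a minimal sketch of this behaviour. It assumes columns are drawn without replacement, which the documentation does not specify. The name samplecolumns_sketch is hypothetical; the optional leading rng argument mirrors the documented signature.

```julia
using Random

# Hypothetical sketch: draw n distinct columns of x at random, or return x
# unchanged if it has fewer than n columns.
function samplecolumns_sketch(rng::AbstractRNG, x::AbstractMatrix, n::Int)
    size(x, 2) < n && return x
    idx = randperm(rng, size(x, 2))[1:n]   # sampling without replacement
    return x[:, idx]
end

samplecolumns_sketch(x::AbstractMatrix, n::Int) =
    samplecolumns_sketch(Random.default_rng(), x, n)
```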
CounterfactualExplanations.Evaluation.to_dataframe - Method
to_dataframe(
computed_measures::Vector,
measure,
report_each::Bool,
pivot_longer::Bool,
store_ce::Bool,
ce::CounterfactualExplanation,
)
Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.
CounterfactualExplanations.Evaluation.validity_strict - Method
validity_strict(ce::CounterfactualExplanation; kwrgs...)
Checks if the counterfactual search has been strictly valid in the sense that it has converged with respect to the pre-specified target probability γ.
CounterfactualExplanations.Evaluation.warmup! - Method
warmup!(
e::EnergySampler,
y::Int;
niter::Int=20,
ntransitions::Int=100,
kwargs...,
)
Warms up the EnergySampler for the underlying model and conditioning value y. Specifically, this entails running PMC for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.
Arguments
- e::EnergySampler: The EnergySampler object to be trained.
- y::Int: The conditioning value.
- opt::Union{Nothing,AbstractSamplingRule}: The sampling rule to be used. By default, ImproperSGLD is used with α = 2 * std(Uniform(𝒟x)) and γ = 0.005α.
- niter::Int=20: The number of iterations for training the sampler through PMC.
- ntransitions::Int=100: The number of transitions for training the sampler. In each transition, the sampler is updated with a mini-batch of data. Data is either drawn from the replay buffer or reinitialized from the prior.
- kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.
Returns
EnergySampler: The trained EnergySampler.
CounterfactualExplanations.DataPreprocessing.InputTransformer - Type
InputTransformer
Abstract type for data transformers. This can be any of the following:
- StatsBase.AbstractDataTransform: A data transformation object from the StatsBase package.
- MultivariateStats.AbstractDimensionalityReduction: A dimensionality reduction object from the MultivariateStats package.
- GenerativeModels.AbstractGenerativeModel: A generative model object from the GenerativeModels module.
CounterfactualExplanations.DataPreprocessing.TypedInputTransformer - Type
TypedInputTransformer
Abstract type for data transformers.
Base.Broadcast.broadcastable - Method
Treats CounterfactualData as a scalar when broadcasting.
CounterfactualExplanations.DataPreprocessing._subset - Method
_subset(data::CounterfactualData, idx::Vector{Int})
Creates a subset of the data.
CounterfactualExplanations.DataPreprocessing.convert_to_1d - Method
convert_to_1d(y::AbstractMatrix, y_levels::AbstractArray)
Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.
Arguments
- y::Matrix: The one-hot encoded matrix.
- y_levels::AbstractArray: The levels of the categorical variable.
Returns
labels: A vector of labels.
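The conversion amounts to finding the position of the 1 in each column. This sketch assumes one row per class and one column per sample; a binary target stored as a single 0/1 row would need separate handling. The name convert_to_1d_sketch is hypothetical.

```julia
# Hypothetical sketch: map each one-hot column of y (classes × samples)
# to the corresponding entry of y_levels.
function convert_to_1d_sketch(y::AbstractMatrix, y_levels::AbstractArray)
    return [y_levels[argmax(col)] for col in eachcol(y)]
end

y = [1 0 0; 0 1 1]   # two classes, three samples
labels = convert_to_1d_sketch(y, ["cat", "dog"])  # ["cat", "dog", "dog"]
```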
CounterfactualExplanations.DataPreprocessing.input_dim - Method
input_dim(counterfactual_data::CounterfactualData)
Helper function that returns the input dimension (number of features) of the data.
CounterfactualExplanations.DataPreprocessing.outdim - Method
outdim(data::CounterfactualData)
Returns the number of output classes.
CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj - Method
preprocess_data_for_mlj(data::CounterfactualData)
Helper function to preprocess data::CounterfactualData for MLJ models.
Arguments
- data::CounterfactualData: The data to be preprocessed.
Returns
- (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.
Example
X, y = preprocess_data_for_mlj(data)
CounterfactualExplanations.DataPreprocessing.reconstruct_cat_encoding - Method
reconstruct_cat_encoding(counterfactual_data::CounterfactualData, x::Vector)
Reconstructs the categorical encoding for a single instance.
CounterfactualExplanations.DataPreprocessing.subsample - Method
subsample(data::CounterfactualData, n::Int)
Helper function to randomly subsample data::CounterfactualData.
CounterfactualExplanations.DataPreprocessing.train_test_split - Method
train_test_split(data::CounterfactualData; test_size=0.2, keep_class_ratio=false)
Splits data into a train and a test split.
Arguments
- data::CounterfactualData: The data to be split.
- test_size=0.2: Proportion of the data to be used for testing.
- keep_class_ratio=false: Decides whether to sample equally from each class, or keep their relative size.
Returns
- (train_data::CounterfactualData, test_data::CounterfactualData): A tuple containing the train and test splits.
Example
train, test = train_test_split(data, test_size=0.1, keep_class_ratio=true)
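One way to implement the keep_class_ratio behaviour is to split each class separately. The sketch below works on a plain label vector rather than CounterfactualData and returns index vectors. The name train_test_indices_sketch is hypothetical. Reading keep_class_ratio=true as "preserve class proportions" is an interpretation of the option name, not confirmed behaviour.

```julia
using Random

# Hypothetical sketch of a train/test split over sample indices.
# With keep_class_ratio=true, each class is split on its own, so class
# proportions are (approximately) preserved in both parts.
function train_test_indices_sketch(rng::AbstractRNG, y::AbstractVector;
                                   test_size=0.2, keep_class_ratio=false)
    groups = keep_class_ratio ? [findall(==(c), y) for c in unique(y)] :
                                [collect(eachindex(y))]
    train, test = Int[], Int[]
    for g in groups
        g = shuffle(rng, g)                        # random order within group
        ntest = round(Int, test_size * length(g))  # test share of this group
        append!(test, g[1:ntest])
        append!(train, g[ntest+1:end])
    end
    return sort(train), sort(test)
end

y = [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]
train, test = train_test_indices_sketch(MersenneTwister(0), y;
                                        test_size=0.2, keep_class_ratio=true)
```

With five samples per class and test_size=0.2, exactly one sample of each class ends up in the test split.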
CounterfactualExplanations.DataPreprocessing.unpack_data - Method
unpack_data(data::CounterfactualData)
Helper function that unpacks data.
CounterfactualExplanations.Models.AbstractCustomDifferentiableModel - Type
Base type for custom differentiable models.
CounterfactualExplanations.Models.AbstractDifferentiableModel - Type
Base type for differentiable models.
CounterfactualExplanations.Models.AbstractDifferentiableModelType - Type
Abstract type for differentiable models.
CounterfactualExplanations.Models.AbstractFluxModel - Type
Base type for differentiable models written in Flux.
CounterfactualExplanations.Models.AbstractFluxNN - Type
Abstract type for Flux models.
CounterfactualExplanations.Models.AbstractMLJModel - Type
Base type for differentiable models from the MLJ library.
CounterfactualExplanations.Models.AbstractModelType - Method
(type::AbstractModelType)(model; likelihood::Symbol=:classification_binary)
Wraps the model type around the pre-trained model model.
CounterfactualExplanations.Models.AbstractModelType - Method
(type::AbstractModelType)(data::CounterfactualData; kwargs...)
Wraps the model type around the data in data. This is a convenience function to avoid having to construct a Model object.
CounterfactualExplanations.Models.Differentiability - Type
A base type for model differentiability.
CounterfactualExplanations.Models.Differentiability - Method
Dispatches on the type of model for the differentiability trait.
CounterfactualExplanations.Models.Fitresult - Type
Fitresult
A struct to hold the results of fitting a model.
Fields
- fitresult: The result of fitting the model to the data. This object should be callable on new data.
- other::Dict: A dictionary to hold any other relevant information.
CounterfactualExplanations.Models.Fitresult - Method
(fitresult::Fitresult)(newdata::AbstractArray)
When called on new data, the Fitresult object returns the result of calling the fitresult on new data.
CounterfactualExplanations.Models.Fitresult - Method
(fitresult::Fitresult)()
CounterfactualExplanations.Models.FluxNN - Type
Concrete type for Flux models.
CounterfactualExplanations.Models.IsDifferentiable - Type
Struct for models that are differentiable.
CounterfactualExplanations.Models.MLJModelType - Type
Abstract type for MLJ models.
CounterfactualExplanations.Models.NonDifferentiable - Type
By default, models are assumed not to be differentiable.
CounterfactualExplanations.Models.binary_to_onehot - Method
binary_to_onehot(p)
Helper function to turn a dummy-encoded variable into a one-hot encoded variable.
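Here is a minimal sketch, assuming p is a 1 × n row holding either 0/1 dummies or positive-class probabilities. The name binary_to_onehot_sketch is hypothetical. Stacking 1 .- p on top of p yields one row per class, so each column sums to 1.

```julia
# Hypothetical sketch: turn a 1 × n row of positive-class indicators or
# probabilities into a 2 × n matrix (row 1: negative class, row 2: positive).
binary_to_onehot_sketch(p::AbstractMatrix) = vcat(1 .- p, p)

binary_to_onehot_sketch([0 1 1])  # [1 0 0; 0 1 1]
```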
CounterfactualExplanations.Models.build_ensemble - Method
build_ensemble(K::Int; kw=(input_dim=2, n_hidden=32, output_dim=1))
Helper function that builds an ensemble of K models.
CounterfactualExplanations.Models.build_mlp - Method
build_mlp()
Helper function to build a simple MLP.
Examples
nn = build_mlp()
CounterfactualExplanations.Models.data_loader - Method
data_loader(data::CounterfactualData)
Prepares counterfactual data for training in Flux.
CounterfactualExplanations.Models.forward! - Method
forward!(model::Flux.Chain, data; loss::Symbol, opt::Symbol, n_epochs::Int=10, model_name="MLP")
Forward pass for training a Flux.Chain model.
CounterfactualExplanations.Models.load_mnist_model - Method
load_mnist_model(type::AbstractModelType)
Empty function to be overloaded for loading a pre-trained model for the AbstractModelType model type.
CounterfactualExplanations.Models.load_mnist_model - Method
load_mnist_model(type::DeepEnsemble)
Loads a pre-trained deep ensemble model for the MNIST dataset.
CounterfactualExplanations.Models.load_mnist_model - Method
load_mnist_model(type::MLP)
Loads a pre-trained MLP model for the MNIST dataset.
CounterfactualExplanations.Models.load_mnist_vae - Method
load_mnist_vae(; strong=true)
Loads a pre-trained VAE model for the MNIST dataset.
CounterfactualExplanations.Models.train - Method
train(M::Model, data::CounterfactualData)
Trains the model M on the data in data.
CounterfactualExplanations.Models.train - Method
train(M::FluxModel, data::CounterfactualData; kwargs...)
Wrapper function to train Flux models.
CounterfactualExplanations.Models.train - Method
train(
M::Model,
type::MLJModelType,
data::CounterfactualData,
)
Overloads the train function for MLJ models.
CounterfactualExplanations.Models.train - Method
train(M::Model, type::DeepEnsemble, data::CounterfactualData; kwargs...)
Overloads the train function for deep ensembles.
CounterfactualExplanations.GenerativeModels.AbstractGMParams - Type
Base type of generative model hyperparameter container.
CounterfactualExplanations.GenerativeModels.AbstractGenerativeModel - Type
Base type for generative models.
CounterfactualExplanations.GenerativeModels.Encoder - Type
Encoder
Constructs the encoder part of the VAE: a simple Flux neural network with one hidden layer and two linear output layers for the first two moments of the latent distribution.
CounterfactualExplanations.GenerativeModels.VAE - Type
VAE <: AbstractGenerativeModel
Constructs the Variational Autoencoder. The VAE is a subtype of AbstractGenerativeModel. Any (sub-)type of AbstractGenerativeModel is accepted by latent space generators.
CounterfactualExplanations.GenerativeModels.VAE - Method
VAE(input_dim; kws...)
Outer method for instantiating a VAE.
CounterfactualExplanations.GenerativeModels.VAEParams - Type
VAEParams <: AbstractGMParams
The default VAE parameters describing both the encoder/decoder architecture and the training process.
CounterfactualExplanations.GenerativeModels.Decoder - Method
Decoder(input_dim::Int, latent_dim::Int, hidden_dim::Int; activation=relu)
The default decoder architecture is just a Flux Chain with one hidden layer and a linear output layer.
CounterfactualExplanations.GenerativeModels.decode - Method
decode(generative_model::VAE, x::AbstractArray)
Decodes an array x using the VAE decoder.
CounterfactualExplanations.GenerativeModels.encode - Method
encode(generative_model::VAE, x::AbstractArray)
Encodes an array x using the VAE encoder. Specifically, it samples from the latent distribution. It does so by first passing x through the encoder to obtain the mean and log-variance of the latent distribution. Then, it samples from the latent distribution using the reparameterization trick. See Random.rand(encoder::Encoder, x, device=cpu) for more details.
CounterfactualExplanations.GenerativeModels.get_data - Method
get_data(X::AbstractArray, y::AbstractArray, batch_size)
Prepares data for mini-batch training.
CounterfactualExplanations.GenerativeModels.get_data - Method
get_data(X::AbstractArray, batch_size)
Prepares data for mini-batch training.
CounterfactualExplanations.GenerativeModels.reconstruct - Function
reconstruct(generative_model::VAE, x, device=cpu)
Implements a full pass of some input x through the VAE: x ↦ x̂.
CounterfactualExplanations.GenerativeModels.reparameterization_trick - Function
reparameterization_trick(μ, logσ, device=cpu)
Helper function that implements the reparameterization trick: z ∼ 𝒩(μ, σ²) ⟺ z = μ + σ ⊙ ε, ε ∼ 𝒩(0, I).
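The identity above can be sketched as follows. The encode entry refers to the log-variance, but this helper's argument is named logσ. The sketch therefore treats logσ as the log standard deviation, so σ = exp.(logσ). That choice, the explicit rng argument and the name reparameterize_sketch are all assumptions. Sampling ε separately keeps z differentiable with respect to μ and logσ.

```julia
using Random

# Hypothetical sketch of the reparameterization trick:
#   z = μ + σ ⊙ ε,  ε ∼ 𝒩(0, I),  with σ = exp.(logσ).
function reparameterize_sketch(rng::AbstractRNG, μ::AbstractArray, logσ::AbstractArray)
    ε = randn(rng, eltype(μ), size(μ))   # standard normal noise, same shape as μ
    return μ .+ exp.(logσ) .* ε          # deterministic in μ and logσ given ε
end
```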
CounterfactualExplanations.Generators.Penalty - Type
Type union for acceptable argument types for the penalty field of GradientBasedGenerator.
CounterfactualExplanations.Convergence.conditions_satisfied - Method
Convergence.conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to check if all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of their respective standard deviations.
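The default stopping rule can be sketched as an elementwise comparison. This hypothetical helper, converged_sketch, assumes features are stored in rows and that each feature's standard deviation is computed over the columns of X.

```julia
using Statistics

# Hypothetical sketch: converged once every proposed feature change Δ is
# smaller than τ (default 1%) of that feature's standard deviation.
function converged_sketch(Δ::AbstractVector, X::AbstractMatrix; τ=0.01)
    σ = vec(std(X; dims=2))   # per-feature std (features in rows)
    return all(abs.(Δ) .< τ .* σ)
end
```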
CounterfactualExplanations.Generators._replace_nans - Function
_replace_nans(grad_ce_state::AbstractArray, old_new::Pair=(NaN => 0))
Helper function to deal with exploding gradients. This is only a temporary fix and will be improved.
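Replacing NaN entries can be done with Base.replace, which matches elements using isequal. As a result, NaN => 0 does match NaN values, even though NaN == NaN is false. The name replace_nans_sketch is hypothetical.

```julia
# Hypothetical sketch: Base.replace compares with isequal, so NaN matches NaN.
replace_nans_sketch(grad::AbstractArray, old_new::Pair=(NaN => 0)) =
    replace(grad, old_new)

replace_nans_sketch([1.0, NaN, -2.0])  # [1.0, 0.0, -2.0]
```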
CounterfactualExplanations.Generators.grad_loss - Method
grad_loss(
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation
)
The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators.
CounterfactualExplanations.Generators.grad_pen - Method
grad_pen(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)
The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that Zygote.jl has gradient access.
If the penalty is not provided, it returns 0.0. By default, Zygote never works out the gradient for constants and instead returns 'nothing', so a manual step is added to override this behaviour. See: https://discourse.julialang.org/t/zygote-gradient/26715.
CounterfactualExplanations.Generators.grad_search_opt - Method
grad_search_opt(
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation,
)
The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It simply computes the weighted sum over partial derivatives. It assumes that Zygote.jl has gradient access. If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, ce::AbstractCounterfactualExplanation)
Dispatches to the appropriate complexity function for any generator.
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided.
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Nothing, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where no penalty is provided.
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)
Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.
CounterfactualExplanations.Generators.incompatible - Method
incompatible(AbstractGenerator, AbstractCounterfactualExplanation)
Checks if the generator is incompatible with any of the additional specifications for the counterfactual explanations. By default, generators are assumed to be compatible.
CounterfactualExplanations.Generators.propose_state - Method
propose_state(
::Models.IsDifferentiable,
generator::AbstractGradientBasedGenerator,
ce::AbstractCounterfactualExplanation,
)
Proposes a new state based on backpropagation for gradient-based generators and differentiable models.
CounterfactualExplanations.Generators.total_loss - Method
total_loss(ce::AbstractCounterfactualExplanation)
Computes the total loss of a counterfactual explanation with respect to the search objective.
CounterfactualExplanations.Objectives.ADRequirements - Type
A base type for AD backend requirements.
CounterfactualExplanations.Objectives.ADRequirements - Method
The energy_constraint function requires ForwardDiff.
CounterfactualExplanations.Objectives.ADRequirements - Method
The hinge_loss function requires ForwardDiff.
CounterfactualExplanations.Objectives.NeedsForwardDiff - Type
This trait implies that ForwardDiff is required.
CounterfactualExplanations.Objectives.NeedsNeighbours - Type
Penalties that need access to neighbors in the target class.
CounterfactualExplanations.Objectives.NoADRequirements - Type
By default, no special AD backend is required.
CounterfactualExplanations.Objectives.NoPenaltyRequirements - Type
By default, penalties have no extra requirements.
CounterfactualExplanations.Objectives.PenaltyRequirements - Type
A base type for penalty requirements.
CounterfactualExplanations.Objectives.PenaltyRequirements - Method
The distance_from_target method needs neighbors in the target class.
CounterfactualExplanations.Objectives.cos_dist - Method
cos_dist(x, y)
Computes the cosine distance between two vectors.
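For reference, the cosine distance is usually defined as 1 minus the cosine similarity. This sketch, cos_dist_sketch (a hypothetical name), assumes that definition. It is undefined for zero vectors.

```julia
using LinearAlgebra

# Hypothetical sketch: cosine distance 1 - ⟨x, y⟩ / (‖x‖ ‖y‖).
cos_dist_sketch(x, y) = 1 - dot(x, y) / (norm(x) * norm(y))

cos_dist_sketch([1.0, 0.0], [0.0, 1.0])  # 1.0 (orthogonal vectors)
```

Parallel vectors give a distance of (numerically close to) 0, and opposite vectors give 2.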
CounterfactualExplanations.Objectives.energy - Method
energy(M::AbstractModel, x::AbstractArray, t::Int)
Computes the energy of the model at a given state as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5.
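To make "energy" concrete, the sketch below uses a toy linear classifier. The energy_constraint entry states that the energy is the negative logit of the target class, so the sketch returns -logit[t]. The toy model and the names toy_logits and energy_sketch are illustrative assumptions, not package functions.

```julia
# Hypothetical toy classifier: logits = W * x .+ b.
toy_logits(W, b, x) = W * x .+ b

# Energy of x with respect to target class t: the negative target logit.
# A lower energy means the model favours class t more strongly.
energy_sketch(W, b, x::AbstractVector, t::Int) = -toy_logits(W, b, x)[t]

W = [1.0 0.0; 0.0 2.0]
b = [0.0, 0.5]
x = [1.0, 1.0]
energy_sketch(W, b, x, 2)  # -2.5
```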
CounterfactualExplanations.Objectives.energy_constraint - Method
energy_constraint(
ce::AbstractCounterfactualExplanation;
agg=mean,
reg_strength::AbstractFloat=0.0,
decay::AbstractFloat=0.9,
kwargs...,
)
Computes the energy constraint for the counterfactual explanation as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5. The energy constraint is a regularization term that penalizes the energy of the counterfactuals. The energy is computed as the negative logit of the target class.
Arguments
- ce::AbstractCounterfactualExplanation: The counterfactual explanation.
- agg::Function=mean: The aggregation function (only applicable in case num_counterfactuals > 1). Default is mean.
- reg_strength::AbstractFloat=0.0: The regularization strength.
- decay::AbstractFloat=0.9: The decay rate for the polynomial decay function (defaults to 0.9). Parameter a is set to 1.0 / ce.generator.opt.eta, such that the initial step size is equal to 1.0, not accounting for b. Parameter b is set to round(Int, max_steps / 20), where max_steps is the maximum number of iterations.
- kwargs...: Additional keyword arguments.
Returns
ℒ::AbstractFloat: The energy constraint.
CounterfactualExplanations.Objectives.energy_constraint - Method
energy_constraint(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.model_loss_penalty β Method
function model_loss_penalty(
ce::AbstractCounterfactualExplanation;
agg=mean
)
Additional penalty for ClaPROARGenerator.
CounterfactualExplanations.Objectives.model_loss_penalty - Method
model_loss_penalty(ce::AbstractCounterfactualExplanation; kwargs...)
Convenience method that computes cf from ce.
CounterfactualExplanations.Objectives.needs_neighbours - Method
needs_neighbours(ce::AbstractCounterfactualExplanation)
Checks if a counterfactual explanation needs access to neighbors in the target class.
CounterfactualExplanations.Objectives.needs_neighbours - Method
needs_neighbours(gen::AbstractGenerator)
Checks if a generator needs access to neighbors in the target class.
Extensions
DecisionTreeExt.AtomicDecisionTree - Type
Type union for DecisionTree decision tree classifiers and regressors.
DecisionTreeExt.AtomicRandomForest - Type
Type union for DecisionTree random forest classifiers and regressors.
CounterfactualExplanations.DecisionTreeModel - Method
CounterfactualExplanations.DecisionTreeModel(
model::AtomicDecisionTree; likelihood::Symbol=:classification_binary
)
Outer constructor for decision trees.
CounterfactualExplanations.Generators.TCRExGenerator - Method
(generator::Generators.TCRExGenerator)(
target::RawTargetType,
data::DataPreprocessing.CounterfactualData,
M::Models.AbstractModel
)
Applies the Generators.TCRExGenerator to a given target and data using the model M. For details see Bewley et al. (2024) [arXiv, PMLR].
CounterfactualExplanations.Models.Model - Method
(M::Models.Model)(
data::CounterfactualData,
type::CounterfactualExplanations.DecisionTreeModel;
kwargs...,
)
Constructs a decision tree for the given data. This method is used internally when a decision-tree model is constructed to be trained from scratch (i.e. no pre-trained model is supplied by the user).
CounterfactualExplanations.Models.Model - Method
(M::Models.Model)(
data::CounterfactualData, type::CounterfactualExplanations.RandomForestModel; kwargs...
)
Constructs a random forest for the given data.
CounterfactualExplanations.RandomForestModel - Method
CounterfactualExplanations.RandomForestModel(
model::AtomicRandomForest; likelihood::Symbol=:classification_binary
)
Outer constructor for random forests.
CounterfactualExplanations.Generators.incompatible - Method
Generators.incompatible(gen::FeatureTweakGenerator, ce::CounterfactualExplanation)
Overloads the incompatible function for the FeatureTweakGenerator.
CounterfactualExplanations.Generators.propose_state - Method
Generators.propose_state(
generator::Generators.FeatureTweakGenerator, ce::AbstractCounterfactualExplanation
)
Overloads the Generators.propose_state method for the FeatureTweakGenerator.
DecisionTreeExt.calculate_delta - Method
calculate_delta(ce::AbstractCounterfactualExplanation, penalty::Vector{Function})
Calculates the penalty for the proposed feature tweak.
Arguments
ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
Returns
delta::Float64: The calculated penalty for the proposed feature tweak.
DecisionTreeExt.classify_prototypes - Method
classify_prototypes(prototypes, rule_assignments, bounds)
Builds the second tree model using the given prototypes as inputs and their corresponding rule_assignments as labels. Split thresholds are restricted to the bounds, which can be computed using partition_bounds(rules). For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.cre - Method
cre(rules, x, X)
Computes the counterfactual rule explanations (CRE) for a given point $x$ and a set of rules, where rules corresponds to the set of maximal-valid rules for some given target. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.esatisfactory_instance - Method
esatisfactory_instance(generator::FeatureTweakGenerator, x::AbstractArray, paths::Dict{String, Dict{String, Any}})
Returns an epsilon-satisfactory counterfactual for x based on the paths provided.
Arguments
generator::FeatureTweakGenerator: The feature tweak generator.
x::AbstractArray: The factual instance.
paths::Dict{String, Dict{String, Any}}: A list of paths to the leaves of the tree to be used for tweaking the feature.
Returns
esatisfactory::AbstractArray: The epsilon-satisfactory instance.
Example
esatisfactory = esatisfactory_instance(generator, x, paths) # returns an epsilon-satisfactory counterfactual for x based on the paths provided
DecisionTreeExt.extract_leaf_rules - Method
extract_leaf_rules(root::DT.Root)
Extracts leaf decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $L$ hyperrectangles. The rules are returned as a vector of tuples containing 2-element tuples, where each 2-element tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].
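To illustrate how leaf rules arise, the self-contained sketch below walks a toy binary tree and records one hyperrectangle per leaf, tightening the upper bound on a left branch and the lower bound on a right branch. The ToyNode and ToyLeaf types and the leaf_rules function are hypothetical stand-ins for the DecisionTree.jl node types. The convention that x[featid] < featval goes left is an assumption, and each rule is stored here as a vector of (lower, upper) tuples.

```julia
# Hypothetical minimal tree types standing in for DecisionTree.jl's Node/Leaf.
abstract type ToyTree end

struct ToyLeaf <: ToyTree
    label::Int
end

struct ToyNode <: ToyTree
    featid::Int       # feature used for the split
    featval::Float64  # threshold; assumed: x[featid] < featval goes left
    left::ToyTree
    right::ToyTree
end

# Collect one hyperrectangle (a (lower, upper) tuple per feature) for each leaf.
function leaf_rules(root::ToyTree, nfeatures::Int)
    rules = Vector{Vector{Tuple{Float64,Float64}}}()
    bounds = [(-Inf, Inf) for _ in 1:nfeatures]  # start unbounded in every feature
    collect_rules!(rules, root, bounds)
    return rules
end

# A leaf closes a path: store the bounds accumulated so far.
collect_rules!(rules, leaf::ToyLeaf, bounds) = push!(rules, copy(bounds))

# An internal node narrows the bounds of its split feature on each branch.
function collect_rules!(rules, node::ToyNode, bounds)
    lo, hi = bounds[node.featid]
    left = copy(bounds)
    left[node.featid] = (lo, min(hi, node.featval))
    right = copy(bounds)
    right[node.featid] = (max(lo, node.featval), hi)
    collect_rules!(rules, node.left, left)
    collect_rules!(rules, node.right, right)
    return rules
end

# A tree with L = 3 leaves yields 3 hyperrectangles.
tree = ToyNode(1, 0.5, ToyLeaf(0), ToyNode(2, 1.0, ToyLeaf(1), ToyLeaf(0)))
rules = leaf_rules(tree, 2)
```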
DecisionTreeExt.extract_leaf_rules - Method
extract_leaf_rules(node::Union{DT.Leaf,DT.Node}, conditions::AbstractArray, decisions::AbstractArray)
See extract_leaf_rules(root::DT.Root) for details.
DecisionTreeExt.extract_rules - Method
extract_rules(root::DT.Root)
Extracts decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $2L-1$ hyperrectangles. The rules are returned as a vector of vectors of 2-element tuples, where each tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.extract_rules - Method
extract_rules(node::DT.Node, conditions::AbstractArray)
DecisionTreeExt.get_individual_classifiers - Method
get_individual_classifiers(M::Model)
Returns the individual classifiers in the forest. If the input is a decision tree, the method returns the decision tree itself inside an array.
Arguments
M::Model: The model selected by the user.
model::CounterfactualExplanations.D
Returns
classifiers::AbstractArray: An array of individual classifiers in the forest.
DecisionTreeExt.grow_surrogate - Method
grow_surrogate(
generator::Generators.TCRExGenerator, X::AbstractArray, ŷ::AbstractArray
)
Grows the tree-based surrogate model for the Generators.TCRExGenerator. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.grow_surrogate - Method
grow_surrogate(
generator::Generators.TCRExGenerator, data::CounterfactualData, M::AbstractModel
)
Overloads the grow_surrogate function to accept a CounterfactualData and an AbstractModel to grow a surrogate model. See grow_surrogate(generator::Generators.TCRExGenerator, X::AbstractArray, ŷ::AbstractArray).
DecisionTreeExt.induced_grid - Method
induced_grid(rules)
Computes the induced grid of the given rules. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.issubrule - Method
issubrule(rule, otherrule)
Checks if the rule hyperrectangle is a subset of the otherrule hyperrectangle. For details see Bewley et al. (2024) [arXiv, PMLR].
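A hyperrectangle is contained in another exactly when, in every feature, its interval lies within the other's interval. The sketch below checks this, assuming each rule is a vector of (lower, upper) tuples with one entry per feature. The name issubrule_sketch is hypothetical.

```julia
# Assumed rule representation: one (lower, upper) tuple per feature.
# rule ⊆ otherrule iff each interval of rule lies within the matching interval of otherrule.
function issubrule_sketch(rule, otherrule)
    return all(
        olo <= lo && hi <= ohi for ((lo, hi), (olo, ohi)) in zip(rule, otherrule)
    )
end

issubrule_sketch([(0.0, 1.0), (-Inf, 2.0)], [(-Inf, 1.0), (-Inf, Inf)])  # true
```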
DecisionTreeExt.max_valid - Method
max_valid(rules, X, fx, target, τ)
Returns the maximal-valid rules for a given target and accuracy threshold τ. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.partition_bounds - Method
partition_bounds(rules, dim::Int)
Computes the set of (unique) bounds for each rule in rules along the dim-th dimension. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.partition_bounds - Method
partition_bounds(rules)
Computes the set of (unique) bounds for each rule in rules and all dimensions. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.prototype - Method
prototype(rule, X; pick_arbitrary::Bool=true)
Picks an arbitrary point $x^C \in X$ (i.e. prototype) from the subset of $X$ that is contained by rule $R_i$. If pick_arbitrary is set to false, the prototype is instead computed as the average across all samples. For details see Bewley et al. (2024) [arXiv, PMLR].
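The two prototype strategies can be sketched as follows. This sketch assumes that X stores one sample per column, that each rule is a vector of (lower, upper) tuples, and that a point lies in a rule when lower <= x[j] < upper for every feature j. The names in_rule and prototype_sketch are hypothetical. The average here is taken over the samples contained by the rule.

```julia
using Statistics: mean

# Assumed containment test: lower <= x[j] < upper for every feature j.
in_rule(rule, x) = all(lo <= xj < hi for ((lo, hi), xj) in zip(rule, x))

function prototype_sketch(rule, X::AbstractMatrix; pick_arbitrary::Bool=true)
    # Keep only the samples (columns) contained by the rule.
    Xr = X[:, [in_rule(rule, x) for x in eachcol(X)]]
    isempty(Xr) && error("no samples are contained by the rule")
    # Either pick an arbitrary contained sample or average the contained samples.
    return pick_arbitrary ? Xr[:, rand(1:size(Xr, 2))] : vec(mean(Xr; dims=2))
end

X = [0.0 0.5 2.0; 0.0 0.5 0.5]  # 2 features × 3 samples
rule = [(0.0, 1.0), (0.0, 1.0)]
prototype_sketch(rule, X; pick_arbitrary=false)  # mean of the first two columns
```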
DecisionTreeExt.rule_accuracy - Method
rule_accuracy(rule, X, fx, target)
Computes the accuracy of the rule on the data X for predicted outputs fx and the target. Accuracy is defined as the fraction of points contained by the rule for which the predicted value matches the target. For details see Bewley et al. (2024) [arXiv, PMLR].
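The accuracy definition above translates directly into code. This is a sketch under the same assumptions as elsewhere: X holds one sample per column, fx is a vector of predicted labels aligned with those columns, and containment means lower <= x[j] < upper. Returning 0.0 for a rule that contains no points is also an assumption. The names in_rule and rule_accuracy_sketch are hypothetical.

```julia
# Assumed containment test: lower <= x[j] < upper for every feature j.
in_rule(rule, x) = all(lo <= xj < hi for ((lo, hi), xj) in zip(rule, x))

function rule_accuracy_sketch(rule, X::AbstractMatrix, fx::AbstractVector, target)
    contained = [in_rule(rule, x) for x in eachcol(X)]
    n = count(contained)
    n == 0 && return 0.0  # assumption: an empty rule has accuracy 0
    # Fraction of contained points whose prediction equals the target.
    return count(fx[contained] .== target) / n
end

X = [0.0 0.5 2.0; 0.0 0.5 0.5]
rule = [(0.0, 1.0), (0.0, 1.0)]
rule_accuracy_sketch(rule, X, [1, 0, 1], 1)  # 1 of the 2 contained points matches
```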
DecisionTreeExt.rule_changes - Method
rule_changes(rule, x)
Computes the number of feature changes necessary for x to be contained by rule $R_i$. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_contains - Method
rule_contains(rule, X)
Returns the subset of X that is contained by rule $R_i$. For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_cost - Method
rule_cost(rule, x, X)
Computes the cost for $x$ to be contained by rule $R_i$, where cost is defined as rule_changes(rule, x) - rule_feasibility(rule, X). For details see Bewley et al. (2024) [arXiv, PMLR].
DecisionTreeExt.rule_feasibility - Method
rule_feasibility(rule, X)
Computes the feasibility of a rule $R_i$ for a given dataset. Feasibility is defined as the fraction of the data points that satisfy the rule. For details see Bewley et al. (2024) [arXiv, PMLR].
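The sketch below shows how rule_contains, rule_feasibility, rule_changes, and rule_cost fit together, with cost defined as documented above (changes minus feasibility). It is not the package code. It assumes each rule is a vector of (lower, upper) tuples, X stores one sample per column, and containment means lower <= x[j] < upper. All *_sketch names and in_rule are hypothetical.

```julia
# Assumed containment test: lower <= x[j] < upper for every feature j.
in_rule(rule, x) = all(lo <= xj < hi for ((lo, hi), xj) in zip(rule, x))

# Subset of the samples (columns) of X that lie inside the rule.
rule_contains_sketch(rule, X::AbstractMatrix) = X[:, [in_rule(rule, x) for x in eachcol(X)]]

# Fraction of the samples in X that satisfy the rule.
rule_feasibility_sketch(rule, X::AbstractMatrix) =
    size(rule_contains_sketch(rule, X), 2) / size(X, 2)

# Number of features of x that fall outside their interval in the rule.
rule_changes_sketch(rule, x::AbstractVector) =
    count(!(lo <= xj < hi) for ((lo, hi), xj) in zip(rule, x))

# Cost as documented: changes needed minus feasibility of the rule.
rule_cost_sketch(rule, x, X) = rule_changes_sketch(rule, x) - rule_feasibility_sketch(rule, X)

X = [0.0 0.5 2.0; 0.0 0.5 0.5]  # 2 features × 3 samples
rule = [(0.0, 1.0), (0.0, 1.0)]
rule_cost_sketch(rule, [2.0, 0.5], X)  # one change, feasibility 2/3
```

Subtracting feasibility means that, between rules requiring the same number of changes, the one covering more of the data is cheaper.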
DecisionTreeExt.search_path - Function
search_path(tree::Union{DT.Leaf, DT.Node}, target::RawTargetType, path::AbstractArray)
Returns a path index list with the inequality symbols, thresholds and feature indices.
Arguments
tree::Union{DT.Leaf, DT.Node}: The root node of a decision tree.
target::RawTargetType: The target class.
path::AbstractArray: A list containing the paths found thus far.
Returns
paths::AbstractArray: A list of paths to the leaves of the tree to be used for tweaking the feature.
Example
paths = search_path(tree, target) # returns a list of paths to the leaves of the tree to be used for tweaking the feature
DecisionTreeExt.wrap_decision_tree - Function
wrap_decision_tree(node::TreeNode, X, y)
Turns a custom decision tree into a DecisionTree.Root object from the DecisionTree.jl package.
DecisionTreeExt.wrap_decision_tree - Method
wrap_decision_tree(node::TreeNode)
CounterfactualExplanations.JEM - Method
CounterfactualExplanations.JEM(
model::JointEnergyModels.JointEnergyClassifier; likelihood::Symbol=:classification_multi
)
Outer constructor for a neural network with JointEnergyClassifier from JointEnergyModels.jl.
CounterfactualExplanations.Models.Model - Method
Models.Model(model, type::CounterfactualExplanations.JEM; likelihood::Symbol=:classification_multi)
Overloaded constructor for Flux models.
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::JEM; kwargs...)
Constructs a JEM model for the given data.
CounterfactualExplanations.Models.load_mnist_model - Method
Models.load_mnist_model(type::CounterfactualExplanations.JEM)
Overload for loading a pre-trained model for the JEM model type.
CounterfactualExplanations.Models.logits - Method
Models.logits(M::JEM, X::AbstractArray)
Calculates the logit scores output by the model M for the input data X.
Arguments
M::JEM: The model selected by the user. Must be a model from the MLJ library.
X::AbstractArray: The feature vector for which the logit scores are calculated.
Returns
logits::Matrix: A matrix of logits for each output class for each data point in X.
Example
logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x
CounterfactualExplanations.Models.probs - Method
Models.probs(
M::Models.Model,
type::CounterfactualExplanations.JEM,
X::AbstractArray,
)
Overloads the Models.probs method for JEM models.
CounterfactualExplanations.Models.train - Method
train(M::JEM, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
M::JEM: The wrapper for a JEM model.
data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
M::JEM: The fitted JEM model.
CounterfactualExplanations.LaplaceReduxModel - Method
CounterfactualExplanations.LaplaceReduxModel(
model::LaplaceRedux.Laplace; likelihood::Symbol=:classification_binary
)
Outer constructor for a neural network with Laplace Approximation from LaplaceRedux.jl.
CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::LaplaceReduxModel; kwargs...)
Constructs a LaplaceReduxModel for the given data.
CounterfactualExplanations.Models.logits - Method
logits(M::LaplaceReduxModel, X::AbstractArray)
Predicts the logit scores for the input data X using the model M.
CounterfactualExplanations.Models.probs - Method
probs(M::LaplaceReduxModel, X::AbstractArray)
Predicts the probabilities of the classes for the input data X using the model M.
CounterfactualExplanations.Models.train - Method
train(M::LaplaceReduxModel, data::CounterfactualData; kwargs...)
Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.
Arguments
M::LaplaceReduxModel: The wrapper for a LaplaceReduxModel model.
data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.
Returns
M::LaplaceReduxModel: The fitted LaplaceReduxModel model.