Reference

In this reference, you will find a detailed overview of the package API.

Reference guides are technical descriptions of the machinery and how to operate it. Reference material is information-oriented.

(Diátaxis)

In other words, you come here because you want to take a very close look at the code 🧐.

Content

Exported functions

CounterfactualExplanations.CounterfactualExplanation - Method
function CounterfactualExplanation(;
    x::AbstractArray,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.AbstractGenerator,
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
)

Outer method to construct a CounterfactualExplanation structure.

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Base.Iterators.Zip,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a zip of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    num_counterfactuals::Int=1,
    initialization::Symbol=:add_perturbation,
    convergence::Union{AbstractConvergence,Symbol}=:decision_threshold,
    timeout::Union{Nothing,Real}=nothing,
)

The core function used to run a counterfactual search for a given factual x, target, counterfactual data, model and generator. Keyword arguments set the number of counterfactuals, the initialization method, the convergence criterion (which determines, for example, the desired threshold for the predicted target class probability and the maximum number of iterations) and an optional timeout.

Arguments

  • x::Matrix: Factual data point.
  • target::RawTargetType: Target class.
  • data::CounterfactualData: Counterfactual data.
  • M::Models.AbstractModel: Fitted model.
  • generator::AbstractGenerator: Generator.
  • num_counterfactuals::Int=1: Number of counterfactuals to generate for the factual.
  • initialization::Symbol=:add_perturbation: Initialization method. By default, the initialization is done by adding a small random perturbation to the factual to achieve more robustness.
  • convergence::Union{AbstractConvergence,Symbol}=:decision_threshold: Convergence criterion. By default, the convergence is based on the decision threshold. Possible values are :decision_threshold, :max_iter, :generator_conditions or a concrete convergence object (e.g. DecisionThresholdConvergence).
  • timeout::Union{Nothing,Real}=nothing: Timeout in seconds.

Examples

Generic generator

julia> using CounterfactualExplanations

julia> using TaijaData

# Counterfactual data and model:

julia> counterfactual_data = CounterfactualData(load_linearly_separable()...);

julia> M = fit_model(counterfactual_data, :Linear);

julia> target = 2;

julia> factual = 1;

julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual));

julia> x = select_factual(counterfactual_data, chosen);

# Search:

julia> generator = Generators.GenericGenerator();

julia> ce = generate_counterfactual(x, target, counterfactual_data, M, generator);

julia> converged(ce.convergence, ce)
true

Broadcasting

The generate_counterfactual method can also be broadcast over a tuple containing an array. This makes it easy to generate counterfactuals for multiple factuals in a single call.

julia> chosen = rand(findall(predict_label(M, counterfactual_data) .== factual), 5);

julia> xs = select_factual(counterfactual_data, chosen);

julia> ces = generate_counterfactual.(xs, target, counterfactual_data, M, generator);

julia> all(ce -> converged(ce.convergence, ce), ces)
true

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Matrix,
    target::RawTargetType,
    data::DataPreprocessing.CounterfactualData,
    M::Models.AbstractModel,
    generator::Generators.GrowingSpheresGenerator;
    num_counterfactuals::Int=1,
    convergence::Union{AbstractConvergence,Symbol}=Convergence.DecisionThresholdConvergence(;
        decision_threshold=(1 / length(data.y_levels)), max_iter=1000
    ),
    kwrgs...,
)

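Overloads the generate_counterfactual method for the GrowingSpheresGenerator. By default, convergence is based on a decision threshold of 1 / (number of classes) with at most 1000 iterations.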

CounterfactualExplanations.generate_counterfactual - Method
generate_counterfactual(
    x::Vector{<:Matrix},
    target::RawTargetType,
    data::CounterfactualData,
    M::Models.AbstractModel,
    generator::AbstractGenerator;
    kwargs...,
)

Overloads the generate_counterfactual method to accept a vector of factuals x and return a vector of counterfactuals.

CounterfactualExplanations.target_probs - Function
target_probs(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Returns the predicted probability of the target class for x. If x is nothing, the predicted probability corresponding to the counterfactual value is returned.

CounterfactualExplanations.Convergence.DecisionThresholdConvergence - Type
DecisionThresholdConvergence

Convergence criterion based on the target class probability threshold. The search stops when the target class probability exceeds the predefined threshold.

Fields

  • decision_threshold::AbstractFloat: The predefined threshold for the target class probability.
  • max_iter::Int: The maximum number of iterations.
  • min_success_rate::AbstractFloat: The minimum success rate, across counterfactuals, with respect to the target class probability threshold.
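
To make the convergence rule concrete, here is a minimal sketch of a decision-threshold check in plain Julia. It is not the package's implementation: the function name threshold_converged is hypothetical, and it assumes that convergence means the share of counterfactuals whose target class probability reaches decision_threshold is at least min_success_rate.

```julia
# Hypothetical sketch of a decision-threshold convergence check (not the package API).
# `probs` holds the predicted target class probability of each counterfactual.
function threshold_converged(
    probs::AbstractVector{<:Real};
    decision_threshold::Real=0.5,
    min_success_rate::Real=0.75,
)
    # Share of counterfactuals whose target class probability reaches the threshold.
    success_rate = count(p -> p >= decision_threshold, probs) / length(probs)
    return success_rate >= min_success_rate
end

threshold_converged([0.9, 0.8, 0.6, 0.4])  # 3 of 4 reach 0.5, so the rate is 0.75 → true
```

In a full search loop, max_iter would additionally cap the number of iterations regardless of this check.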

CounterfactualExplanations.Convergence.GeneratorConditionsConvergence - Type
GeneratorConditionsConvergence

Convergence criterion for counterfactual explanations based on the generator conditions. The search stops when the gradients of the search objective are below a certain threshold and the generator conditions are satisfied.

Fields

  • decision_threshold::AbstractFloat: The threshold for the decision probability.
  • gradient_tol::AbstractFloat: The tolerance for the gradients of the search objective.
  • max_iter::Int: The maximum number of iterations.
  • min_success_rate::AbstractFloat: The minimum success rate for the generator conditions (across counterfactuals).

CounterfactualExplanations.Convergence.converged - Function
converged(
    convergence::InvalidationRateConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is the invalidation rate.

CounterfactualExplanations.Convergence.converged - Function
converged(
    convergence::DecisionThresholdConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is the decision threshold.

CounterfactualExplanations.Convergence.converged - Function
converged(
    convergence::MaxIterConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is maximum iterations. This means the counterfactual search will not terminate until the maximum number of iterations has been reached, regardless of the other convergence criteria.

CounterfactualExplanations.Convergence.converged - Function
converged(
    convergence::GeneratorConditionsConvergence,
    ce::AbstractCounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Checks if the counterfactual search has converged when the convergence criterion is generator_conditions.

CounterfactualExplanations.Convergence.invalidation_rate - Method
invalidation_rate(ce::AbstractCounterfactualExplanation)

Calculates the invalidation rate of a counterfactual explanation.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the invalidation rate for.
  • kwargs: Additional keyword arguments to pass to the function.

Returns

The invalidation rate of the counterfactual explanation.
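
As we understand it, Pawelczyk et al. (2022) define the invalidation rate of a counterfactual x̌ as the expected change in the predicted label when x̌ is perturbed by Gaussian noise, E[h(x̌) - h(x̌ + ε)] with ε ~ N(0, σ²I). The sketch below estimates that quantity by Monte Carlo for a hypothetical linear classifier. The names linear_label and mc_invalidation_rate, the noise scale σ and the sample size n are illustrative assumptions; the package may compute the rate differently (for example, in closed form).

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical linear classifier h(x): returns 1 if w'x + b > 0, else 0.
linear_label(w, b, x) = dot(w, x) + b > 0 ? 1 : 0

# Monte Carlo estimate of E[h(x_cf) - h(x_cf + ε)] with ε ~ N(0, σ²I).
# Sketch only; not how the package computes `invalidation_rate`.
function mc_invalidation_rate(w, b, x_cf; σ=0.1, n=10_000, rng=MersenneTwister(42))
    h0 = linear_label(w, b, x_cf)
    diffs = [h0 - linear_label(w, b, x_cf .+ σ .* randn(rng, length(x_cf))) for _ in 1:n]
    return mean(diffs)
end

mc_invalidation_rate([1.0, 1.0], 0.0, [10.0, 10.0])  # far from the boundary → 0.0
mc_invalidation_rate([1.0, 1.0], 0.0, [1e-6, 0.0])   # on the boundary → roughly 0.5
```

Counterfactuals that sit right on the decision boundary are invalidated about half the time, which is why robust generators push them further into the target region.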

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    data::CounterfactualData;
    models::Dict{<:Any,<:Any}=standard_models_catalogue,
    generators::Union{Nothing,Dict{<:Any,<:AbstractGenerator}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    n_individuals::Int=5,
    suppress_training::Bool=false,
    factual::Union{Nothing,RawTargetType}=nothing,
    target::Union{Nothing,RawTargetType}=nothing,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

Runs the benchmarking exercise as follows:

  1. Randomly choose a factual and target label unless specified.
  2. If no pretrained models are provided, it is assumed that a dictionary of callable model objects is provided (by default using the standard_models_catalogue).
  3. Each of these models is then trained on the data.
  4. For each model separately choose n_individuals randomly from the non-target (factual) class. For each generator create a benchmark as in benchmark(xs::Union{AbstractArray,Base.Iterators.Zip}).
  5. Finally, concatenate the results.

If vertical_splits is set to an integer, the computations are split vertically into vertical_splits chunks. In this case, the results are stored in a temporary directory and concatenated afterwards.

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    x::Union{AbstractArray,Base.Iterators.Zip},
    target::RawTargetType,
    data::CounterfactualData;
    models::Dict{<:Any,<:AbstractModel},
    generators::Dict{<:Any,<:AbstractGenerator},
    measure::Union{Function,Vector{Function}}=default_measures,
    xids::Union{Nothing,AbstractArray}=nothing,
    dataname::Union{Nothing,Symbol,String}=nothing,
    verbose::Bool=true,
    store_ce::Bool=false,
    parallelizer::Union{Nothing,AbstractParallelizer}=nothing,
    kwrgs...,
)

First generates counterfactual explanations for factual x, the target and data using each of the provided models and generators. Then generates a Benchmark for the vector of counterfactual explanations as in benchmark(counterfactual_explanations::Vector{CounterfactualExplanation}).

CounterfactualExplanations.Evaluation.benchmark - Method
benchmark(
    counterfactual_explanations::Vector{CounterfactualExplanation};
    meta_data::Union{Nothing,<:Vector{<:Dict}}=nothing,
    measure::Union{Function,Vector{Function}}=default_measures,
    store_ce::Bool=false,
)

Generates a Benchmark for a vector of counterfactual explanations. Optionally meta_data describing each individual counterfactual explanation can be supplied. This should be a vector of dictionaries of the same length as the vector of counterfactuals. If no meta_data is supplied, it will be automatically inferred. All measure functions are applied to each counterfactual explanation. If store_ce=true, the counterfactual explanations are stored in the benchmark.

CounterfactualExplanations.Evaluation.evaluate - Function
evaluate(
    ce::CounterfactualExplanation;
    measure::Union{Function,Vector{Function}}=default_measures,
    agg::Function=mean,
    report_each::Bool=false,
    output_format::Symbol=:Vector,
    pivot_longer::Bool=true
)

Computes evaluation measures for the counterfactual explanation. By default, no meta data is reported. For report_meta=true, meta data is automatically inferred unless it is overridden by meta_data. The optional meta_data argument should be a vector of dictionaries of the same length as the vector of counterfactual explanations.

CounterfactualExplanations.Evaluation.faithfulness - Method
faithfulness(
    ce::CounterfactualExplanation,
    fun::typeof(Objectives.distance_from_target);
    λ::AbstractFloat=1.0,
    kwrgs...,
)

Computes the faithfulness of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the model posterior through SGLD (see distance_from_posterior).

CounterfactualExplanations.Evaluation.plausibility - Method
plausibility(
    ce::CounterfactualExplanation,
    fun::typeof(Objectives.distance_from_target);
    K=nothing,
    kwrgs...,
)

Computes the plausibility of a counterfactual explanation based on the cosine similarity between the counterfactual and samples drawn from the target distribution.

CounterfactualExplanations.Evaluation.validity - Method
validity(ce::CounterfactualExplanation; γ=0.5)

Checks if the counterfactual search has been successful with respect to the probability threshold γ. In case multiple counterfactuals were generated, the function returns the proportion of successful counterfactuals.

CounterfactualExplanations.DataPreprocessing.CounterfactualData - Method
CounterfactualData(
    X::AbstractMatrix,
    y::RawOutputArrayType;
    mutability::Union{Vector{Symbol},Nothing}=nothing,
    domain::Union{Any,Nothing}=nothing,
    features_categorical::Union{Vector{Vector{Int}},Nothing}=nothing,
    features_continuous::Union{Vector{Int},Nothing}=nothing,
    input_encoder::Union{Nothing,InputTransformer,TypedInputTransformer}=nothing,
)

This outer constructor method prepares features X and labels y to be used with the package. Mutability and domain constraints can be added for the features. The function also accepts arguments that specify which features are categorical and which are continuous. These arguments are currently not used.

Examples

using CounterfactualExplanations
using TaijaData
X, y = load_linearly_separable()
counterfactual_data = CounterfactualData(X, y)

CounterfactualExplanations.Models.Model - Method
Model(model, type::AbstractModelType; likelihood::Symbol=:classification_binary)

Outer constructor for Model where the atomic model is defined and assumed to be pre-trained.

CounterfactualExplanations.Models.Model - Method
(M::Model)(data::CounterfactualData, type::Linear; kwargs...)

Constructs a model with one linear layer for the given data. If the output is binary, this corresponds to logistic regression, since model outputs are passed through the sigmoid function. If the output is multi-class, this corresponds to multinomial logistic regression, since model outputs are passed through the softmax function.
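
The link between a single linear layer and (multinomial) logistic regression can be seen in a few lines of plain Julia. This sketch is independent of the package: the weights W, w and biases b, bvec are made-up toy values, and sigmoid and softmax are defined locally rather than taken from Flux.

```julia
# Hypothetical single linear layer: logits z = W * x .+ b.
sigmoid(z) = 1 / (1 + exp(-z))
function softmax(z::AbstractVector)
    e = exp.(z .- maximum(z))   # subtract the max for numerical stability
    return e ./ sum(e)
end

x = [1.0, -2.0]

# Binary case: one output passed through the sigmoid (logistic regression).
w, b = [0.5, 0.25], 0.0
p_binary = sigmoid(w' * x + b)    # w'x = 0.5 - 0.5 = 0, so the probability is 0.5

# Multi-class case: one output per class passed through the softmax
# (multinomial logistic regression).
W = [1.0 0.0; 0.0 1.0; 1.0 1.0]
bvec = zeros(3)
p_multi = softmax(W * x .+ bvec)  # logits [1, -2, -1]; class 1 is most likely
```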

CounterfactualExplanations.Models.fit_model - Function
fit_model(
    counterfactual_data::CounterfactualData, model::Symbol=:MLP;
    kwrgs...
)

Fits one of the available default models to the counterfactual_data. The model argument can be used to specify the desired model. The available values correspond to the keys of the all_models_catalogue dictionary.

CounterfactualExplanations.Models.fit_model - Method
fit_model(
    counterfactual_data::CounterfactualData, type::AbstractModelType; kwrgs...
)

A wrapper function to fit a model to the counterfactual_data for a given type of model.

Arguments

  • counterfactual_data::CounterfactualData: The data to be used for training the model.
  • type::AbstractModelType: The type of model to be trained, e.g., MLP or DecisionTreeModel.

Examples

julia> using CounterfactualExplanations

julia> using CounterfactualExplanations.Models

julia> using TaijaData

julia> data = CounterfactualData(load_linearly_separable()...);

julia> M = fit_model(data, Linear())
CounterfactualExplanations.Models.Model(Chain(Dense(2 => 2)), :classification_multi, CounterfactualExplanations.Models.Fitresult(Chain(Dense(2 => 2)), Dict{Any, Any}()), Linear())

CounterfactualExplanations.Models.model_evaluation - Method
model_evaluation(M::AbstractModel, test_data::CounterfactualData)

Helper function to evaluate an AbstractModel on a (test) data set. By default, it computes the accuracy. Any other measure, e.g. the F-score or a measure from the StatisticalMeasures package, can be passed as an argument. Currently, only measures applicable to classification tasks are supported.
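
For reference, the two measures mentioned above reduce to simple counts over true and predicted labels. The helpers accuracy and f1_score below are hypothetical plain-Julia sketches for the binary case, not functions exported by the package or by StatisticalMeasures.

```julia
using Statistics

# Hypothetical helpers (not the package API): accuracy and binary F1 score
# from true labels `y` and predicted labels `yhat`.
accuracy(y, yhat) = mean(y .== yhat)

function f1_score(y, yhat; positive=1)
    tp = count((y .== positive) .& (yhat .== positive))   # true positives
    fp = count((y .!= positive) .& (yhat .== positive))   # false positives
    fn = count((y .== positive) .& (yhat .!= positive))   # false negatives
    return 2tp / (2tp + fp + fn)
end

y    = [1, 0, 1, 1, 0]
yhat = [1, 0, 0, 1, 1]
accuracy(y, yhat)   # 3 of 5 correct → 0.6
f1_score(y, yhat)   # tp = 2, fp = 1, fn = 1 → 4 / 6 ≈ 0.667
```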

CounterfactualExplanations.Models.predict_proba - Method
predict_proba(M::AbstractModel, counterfactual_data::CounterfactualData, X::Union{Nothing,AbstractArray})

Returns the predicted output probabilities for a given model M, data set counterfactual_data and input data X.

CounterfactualExplanations.Models.probs - Method
probs(
    M::Model,
    type::MLJModelType,
    X::AbstractArray,
)

Overloads the probs method for MLJ models.

Note for developers

Note that the underlying MLJ methods (reformat, predict) are currently incompatible with Zygote's autodiff. For differentiable MLJ models, the probs and logits methods need to be overloaded.

CounterfactualExplanations.Generators.FeatureTweakGenerator - Method
FeatureTweakGenerator(; penalty::Union{Nothing,Function,Vector{Function}}=Objectives.distance_l2, ϵ::AbstractFloat=0.1)

Constructs a new Feature Tweak Generator object.

By default, the L2-norm is used as the penalty to measure the distance between the counterfactual and the factual. According to the paper by Tolomei et al., another recommended choice for the penalty is the L0-norm, which simply minimizes the number of features that are changed through the tweak.

Arguments

  • penalty::Penalty: The penalty function to use for the generator. Defaults to distance_l2.
  • ϵ::AbstractFloat: The tolerance value for the feature tweaks. Described at length in Tolomei et al. (https://arxiv.org/pdf/1706.06691.pdf). Defaults to 0.1.

Returns

  • generator::FeatureTweakGenerator: A non-gradient-based generator that can be used to generate counterfactuals using the feature tweak method.

CounterfactualExplanations.Generators.GradientBasedGenerator - Method
GradientBasedGenerator(;
    loss::Union{Nothing,Function}=nothing,
    penalty::Penalty=nothing,
    λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing,
    latent_space::Bool=false,
    opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(),
    generative_model_params::NamedTuple=(;),
)

Default outer constructor for GradientBasedGenerator.

Arguments

  • loss::Union{Nothing,Function}=nothing: The loss function used by the model.
  • penalty::Penalty=nothing: A penalty function for the generator to penalize counterfactuals too far from the original point.
  • λ::Union{Nothing,AbstractFloat,Vector{AbstractFloat}}=nothing: The weight of the penalty function.
  • latent_space::Bool=false: Whether to use the latent space of a generative model to generate counterfactuals.
  • opt::Flux.Optimise.AbstractOptimiser=Flux.Descent(): The optimizer to use for the generator.
  • generative_model_params::NamedTuple: The parameters of the generative model associated with the generator.

Returns

  • generator::GradientBasedGenerator: A gradient-based counterfactual generator.
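
The search driven by a gradient-based generator can be pictured as plain gradient descent on a loss plus a λ-weighted penalty, stopping once the target class probability reaches a decision threshold. The sketch below assumes a logistic regression model, cross-entropy loss and a squared L2 penalty, with gradients written out by hand; the function gradient_search and all its parameters are hypothetical and do not reflect the package's internals (which use Flux optimisers and automatic differentiation).

```julia
using LinearAlgebra

sigmoid(z) = 1 / (1 + exp(-z))

# Hypothetical sketch of a gradient-based counterfactual search (not the package API).
# Model: logistic regression p(y = 1 | x) = sigmoid(w'x + b).
# Objective: cross-entropy w.r.t. `target` (0 or 1) + λ * squared L2 distance to the factual.
function gradient_search(w, b, x_factual; target=1, λ=0.1, η=0.5,
                         decision_threshold=0.5, max_iter=1000)
    x_cf = copy(x_factual)
    for iter in 0:max_iter-1
        p = sigmoid(dot(w, x_cf) + b)
        p_target = target == 1 ? p : 1 - p
        if p_target >= decision_threshold
            return x_cf, iter            # converged after `iter` steps
        end
        # Analytic gradients: ∂loss/∂x = (p - target) * w, ∂penalty/∂x = 2λ(x_cf - x_factual).
        grad = (p - target) .* w .+ 2λ .* (x_cf .- x_factual)
        x_cf .-= η .* grad               # plain gradient descent step
    end
    return x_cf, max_iter                # stopped at max_iter without converging
end

x_cf, steps = gradient_search([1.0, 1.0], 0.0, [-1.0, -1.0])
```

A larger λ keeps the counterfactual closer to the factual at the cost of more iterations, which is the trade-off the penalty weight controls.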

CounterfactualExplanations.Generators.ProbeGenerator - Method

Constructor for ProbeGenerator. For details, see Pawelczyk et al. (2022).

Warning

The ProbeGenerator is currently not working adequately. In particular, gradients are not computed with respect to the hinge loss term proposed in the paper. It is still possible, however, to use this generator to achieve a desired invalidation rate. See issue #376 for details.

CounterfactualExplanations.Objectives.distance - Method
distance(
    ce::AbstractCounterfactualExplanation;
    from::Union{Nothing,AbstractArray}=nothing,
    agg=mean,
    p::Real=1,
    weights::Union{Nothing,AbstractArray}=nothing,
)

Computes the distance of the counterfactual to the original factual.
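
The keyword arguments suggest a p-norm distance computed per counterfactual and then aggregated with agg. The sketch below illustrates that reading in plain Julia; cf_distance is a hypothetical name, counterfactuals are assumed to be stored as columns, and applying weights elementwise to the differences is an assumption about what the weights keyword does.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch (not the package API): p-norm distance of each counterfactual
# (one per column of `X_cf`) to the factual `x`, aggregated with `agg`.
# The optional `weights` are applied elementwise to the differences (an assumption).
function cf_distance(X_cf::AbstractMatrix, x::AbstractVector; p::Real=1, agg=mean,
                     weights=nothing)
    D = X_cf .- x                          # broadcast the factual over columns
    isnothing(weights) || (D = weights .* D)
    return agg([norm(col, p) for col in eachcol(D)])
end

X_cf = [3.0 0.0; 4.0 1.0]   # two counterfactuals: [3, 4] and [0, 1]
x = [0.0, 0.0]
cf_distance(X_cf, x; p=1)   # mean of 7 and 1 → 4.0
cf_distance(X_cf, x; p=2)   # mean of 5 and 1 → 3.0
```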

CounterfactualExplanations.Objectives.hinge_loss - Method
hinge_loss(ce::AbstractCounterfactualExplanation)

Calculates the hinge loss of a counterfactual explanation with InvalidationRateConvergence.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation to calculate the hinge loss for.

Returns

The hinge loss of the counterfactual explanation.

Flux.Losses.logitbinarycrossentropy - Method
Flux.Losses.logitbinarycrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitbinarycrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.logitcrossentropy - Method
Flux.Losses.logitcrossentropy(ce::AbstractCounterfactualExplanation)

Simply extends the logitcrossentropy method to work with objects of type AbstractCounterfactualExplanation.

Flux.Losses.mse - Method
Flux.Losses.mse(ce::AbstractCounterfactualExplanation)

Simply extends the mse method to work with objects of type AbstractCounterfactualExplanation.

Internal functions

CounterfactualExplanations.CRE - Type
CRE <: AbstractCounterfactualExplanation

A Counterfactual Rule Explanation (CRE) is a global explanation for a given target, model M, data and generator.

CounterfactualExplanations.JEM - Type
JEM

Concrete type for joint-energy models from JointEnergyModels. Since JointEnergyModels has an MLJ interface, we subtype the MLJModelType model type.

CounterfactualExplanations.LaplaceReduxModel - Type
LaplaceReduxModel

Concrete type for neural networks with Laplace Approximation from the LaplaceRedux package. Currently subtyping the AbstractFluxNN model type, although this may be changed to MLJ in the future.

CounterfactualExplanations.apply_mutability - Method
apply_mutability(
    ce::CounterfactualExplanation,
    Δcounterfactual_state::AbstractArray,
)

A subroutine that applies mutability constraints to the proposed vector of feature perturbations.
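
One way to picture this subroutine: each feature carries a mutability label, and the proposed perturbation for that feature is clipped accordingly. The sketch assumes the labels :both, :increase, :decrease and :none, and the function name apply_mutability_sketch is hypothetical; it is not the package's implementation.

```julia
# Hypothetical sketch (not the package's implementation) of mutability constraints.
# Each feature is labelled :both, :increase, :decrease or :none (assumed labels).
function apply_mutability_sketch(Δ::AbstractVector, mutability::Vector{Symbol})
    return map(zip(Δ, mutability)) do (δ, m)
        if m === :both
            δ                        # free to move in either direction
        elseif m === :increase
            max(δ, zero(δ))          # only increases are allowed
        elseif m === :decrease
            min(δ, zero(δ))          # only decreases are allowed
        else
            zero(δ)                  # :none, the feature is immutable
        end
    end
end

apply_mutability_sketch([0.5, -0.5, 0.5, 0.5], [:both, :increase, :decrease, :none])
# → [0.5, 0.0, 0.0, 0.0]
```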

CounterfactualExplanations.decode_array - Method
decode_array(
    data::CounterfactualData,
    dt::CausalInference.SCM,
    x::AbstractArray,
)

Helper function to decode an array x using a data transform dt::CausalInference.SCM.

CounterfactualExplanations.decode_array - Method
decode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)

Helper function to decode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.

CounterfactualExplanations.decode_array - Method
decode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to decode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.decode_state - Function
decode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all the applicable decoding functions:

  1. If applicable, map the state variable back from the latent space to the feature space.
  2. If and where applicable, inverse-transform features.
  3. Reconstruct all categorical encodings.

Finally, the decoded counterfactual is returned.

CounterfactualExplanations.encode_array - Method
encode_array(data::CounterfactualData, dt::CausalInference.SCM, x::AbstractArray)

Helper function to encode an array x using a data transform dt::CausalInference.SCM. This is a no-op.

CounterfactualExplanations.encode_array - Method
encode_array(dt::GenerativeModels.AbstractGenerativeModel, x::AbstractArray)

Helper function to encode an array x using a data transform dt::GenerativeModels.AbstractGenerativeModel.

CounterfactualExplanations.encode_array - Method
encode_array(dt::MultivariateStats.AbstractDimensionalityReduction, x::AbstractArray)

Helper function to encode an array x using a data transform dt::MultivariateStats.AbstractDimensionalityReduction.

CounterfactualExplanations.encode_state - Function
encode_state(
    ce::CounterfactualExplanation,
    x::Union{AbstractArray,Nothing}=nothing,
)

Applies all required encodings to x:

  1. If applicable, it maps x to the latent space learned by the generative model.
  2. If and where applicable, it rescales features.

Finally, it returns the encoded state variable.

CounterfactualExplanations.guess_likelihood - Method
guess_likelihood(y::RawOutputArrayType)

Guess the likelihood based on the scientific type of the output array. Returns a symbol indicating the guessed likelihood and the scientific type of the output array.

CounterfactualExplanations.initialize! - Method
initialize!(ce::CounterfactualExplanation)

Initializes the counterfactual explanation. This method is called by the constructor. It does the following:

  1. Creates a dictionary to store information about the search.
  2. Initializes the counterfactual state.
  3. Initializes the search path.
  4. Initializes the loss.

CounterfactualExplanations.initialize_state - Method
initialize_state(ce::CounterfactualExplanation)

Initializes the starting point for the factual(s):

  1. If ce.initialization is set to :identity or counterfactuals are searched in a latent space, then nothing is done.
  2. If ce.initialization is set to :add_perturbation, then a random perturbation is added to the factual following Slack et al. (2021): https://arxiv.org/abs/2106.02666. The authors show that this improves adversarial robustness.
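
The two initialization options can be sketched as follows. The function initialize_state_sketch is hypothetical, and the choice of Gaussian noise with scale σ = 0.1 is an assumption made for illustration; the package may draw the perturbation from a different distribution or scale.

```julia
using Random

# Hypothetical sketch of the two initialization options (not the package's implementation).
# Gaussian noise with scale `σ` is an assumption; the package may use something else.
function initialize_state_sketch(x::AbstractVector, initialization::Symbol;
                                 σ=0.1, rng=Random.default_rng())
    initialization === :identity && return copy(x)
    initialization === :add_perturbation && return x .+ σ .* randn(rng, length(x))
    throw(ArgumentError("unknown initialization: $initialization"))
end

x = [1.0, 2.0]
initialize_state_sketch(x, :identity)                                  # unchanged copy of x
initialize_state_sketch(x, :add_perturbation; rng=MersenneTwister(1))  # x plus small noise
```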

CounterfactualExplanations.polynomial_decay - Method
polynomial_decay(a::Real, b::Real, decay::Real, t::Int)

Computes the polynomial decay function as in Welling and Teh (2011): https://www.stats.ox.ac.uk/~teh/research/compstats/WelTeh2011a.pdf.
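
In Welling and Teh (2011), the SGLD step size follows the schedule ϵ_t = a(b + t)^(-γ) with γ ∈ (0.5, 1], so that the step sizes shrink over time. Assuming the decay argument plays the role of γ, the schedule can be written as a one-liner; polynomial_decay_sketch is a hypothetical name, not the package function.

```julia
# Sketch of the polynomial step-size schedule ϵ_t = a * (b + t)^(-decay) from Welling and Teh (2011).
# That the package's `polynomial_decay` uses exactly this form is an assumption.
polynomial_decay_sketch(a::Real, b::Real, decay::Real, t::Int) = a * (b + t)^(-decay)

[polynomial_decay_sketch(1.0, 1.0, 0.55, t) for t in 0:3]   # step sizes shrink as t grows
```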

CounterfactualExplanations.update! - Method
update!(ce::CounterfactualExplanation)

An important subroutine that updates the counterfactual explanation. It takes a snapshot of the current counterfactual search state and passes it to the generator. Based on the current state, the generator generates perturbations. Various constraints are then applied to the proposed vector of feature perturbations. Finally, the counterfactual search state is updated.

CounterfactualExplanations.Evaluation.EnergySampler - Method
EnergySampler(
    model::AbstractModel,
    𝒟x::Distribution,
    𝒟y::Distribution,
    input_size::Dims,
    yidx::Int;
    opt::Union{Nothing,AbstractSamplingRule}=nothing,
    nsamples::Int=100,
    niter_final::Int=1000,
    ntransitions::Int=0,
    opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing,
    niter::Int=20,
    batch_size::Int=50,
    prob_buffer::AbstractFloat=0.95,
    kwargs...,
)

Constructor for EnergySampler, which is used to sample from the posterior distribution of the model conditioned on y.

Arguments

  • model::AbstractModel: The model to be used for sampling.
  • data::CounterfactualData: The data to be used for sampling.
  • y::Any: The conditioning value.
  • opt::AbstractSamplingRule=ImproperSGLD(): The sampling rule to be used. By default, SGLD is used with a = (2 / std(Uniform())) * std(𝒟x), b = 1 and γ = 0.9.
  • nsamples::Int=100: The number of samples to include in the final empirical posterior distribution.
  • niter_final::Int=1000: The number of iterations for generating samples from the posterior distribution. Typically, this number will be larger than the number of iterations during PMC training.
  • ntransitions::Int=0: The number of transitions for (optionally) warming up the sampler. By default, this is set to 0 and the sampler is not warmed up. For values larger than 0, the sampler is trained through PMC for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.
  • opt_warmup::Union{Nothing,AbstractSamplingRule}=nothing: The sampling rule to be used for warm-up. By default, ImproperSGLD is used with α = (2 / std(Uniform())) * std(𝒟x) and γ = 0.005α.
  • niter::Int=20: The number of iterations for training the sampler through PMC.
  • batch_size::Int=50: The batch size for training the sampler.
  • prob_buffer::AbstractFloat=0.95: The probability of drawing samples from the replay buffer. Smaller values will result in more samples being drawn from the prior and typically lead to better mixing and diversity in the samples.
  • kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.

Returns

  • EnergySampler: An instance of EnergySampler.

Base.rand - Function
Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)

Overloads the rand method to randomly draw n samples from EnergySampler. If from_posterior is true, the samples are drawn from the posterior distribution. Otherwise, the samples are generated from the model conditioned on the target value using a single chain (see generate_posterior_samples).

Arguments

  • sampler::EnergySampler: The EnergySampler object to be used for sampling.
  • n::Int=100: The number of samples to draw.
  • from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
  • niter::Int=500: The number of iterations for generating samples through Monte Carlo sampling (single chain).

Returns

  • AbstractArray: The samples.

Base.vcat - Method
Base.vcat(bmk1::Benchmark, bmk2::Benchmark)

Vertically concatenates two Benchmark objects.

CounterfactualExplanations.Evaluation.compute_measure - Method
compute_measure(ce::CounterfactualExplanation, measure::Function, agg::Function)

Computes a single measure for a counterfactual explanation. The measure is applied to the counterfactual explanation ce and aggregated using the aggregation function agg.

CounterfactualExplanations.Evaluation.define_prior - Method
define_prior(
    data::CounterfactualData;
    𝒟x::Union{Nothing,Distribution}=nothing,
    𝒟y::Union{Nothing,Distribution}=nothing,
    n_std::Int=3,
)

Defines the prior for the data. The space is defined as a uniform distribution with bounds defined by the mean and standard deviation of the data. The bounds are extended by n_std standard deviations.

Arguments

  • data::CounterfactualData: The data to be used for defining the prior sampling space.
  • n_std::Int=3: The number of standard deviations to extend the bounds.

Returns

  • Uniform: The uniform distribution defining the prior sampling space.
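
A minimal reading of the description is a single interval centred on the mean of the data and extended by n_std standard deviations. The sketch below computes those bounds with the Statistics standard library; prior_bounds is hypothetical, the use of the overall (not per-feature) mean and standard deviation is an assumption, and the package itself returns a Distributions.Uniform rather than a tuple.

```julia
using Statistics

# Hypothetical sketch (not the package API): a single range centred on the overall
# mean of the data and extended by `n_std` standard deviations.
# The package returns a `Distributions.Uniform`; this sketch returns the (lower, upper) bounds.
function prior_bounds(X::AbstractMatrix; n_std::Int=3)
    μ, σ = mean(X), std(X)
    return (μ - n_std * σ, μ + n_std * σ)
end

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
lower, upper = prior_bounds(X; n_std=2)   # mean 3.5 ± 2 standard deviations
```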

CounterfactualExplanations.Evaluation.distance_from_posterior - Method
distance_from_posterior(ce::AbstractCounterfactualExplanation)

Computes the distance from the counterfactual to generated conditional samples. The distance is computed as the mean distance from the counterfactual to the samples drawn from the posterior distribution of the model. By default, the cosine distance is used.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation object.
  • nsamples::Int=1000: The number of samples to draw.
  • from_posterior::Bool=true: Whether to draw samples from the posterior distribution.
  • agg: The aggregation function to use for computing the distance.
  • choose_lowest_energy::Bool=true: Whether to choose the samples with the lowest energy.
  • choose_random::Bool=false: Whether to choose random samples.
  • nmin::Int=25: The minimum number of samples to choose.
  • p::Int=1: The norm to use for computing the distance.
  • cosine::Bool=true: Whether to use the cosine distance.
  • kwargs...: Additional keyword arguments to be passed on to the EnergySampler.

Returns

  • AbstractFloat: The distance from the counterfactual to the samples.
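
With cosine=true, the distance boils down to one minus the cosine similarity between the counterfactual and each conditional sample, aggregated over samples. The helpers cosine_distance and mean_cosine_distance below are hypothetical plain-Julia sketches that assume samples are stored as columns; they leave out the sampling step performed by the EnergySampler.

```julia
using LinearAlgebra, Statistics

# Hypothetical sketch (not the package API): cosine distance between a counterfactual
# `x_cf` and conditional samples stored as the columns of `S`, aggregated with `agg`.
cosine_distance(a, b) = 1 - dot(a, b) / (norm(a) * norm(b))

mean_cosine_distance(x_cf, S; agg=mean) = agg([cosine_distance(x_cf, s) for s in eachcol(S)])

S = [1.0 0.0; 0.0 1.0]   # two samples: [1, 0] and [0, 1]
mean_cosine_distance([1.0, 0.0], S)   # (0 + 1) / 2 → 0.5
```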

CounterfactualExplanations.Evaluation.generate_posterior_samples - Function
generate_posterior_samples(
    e::EnergySampler, n::Int=1000; niter::Int=1000, kwargs...
)

Generates n samples from the posterior distribution of the model conditioned on the target value y. The samples are generated through (Persistent) Monte Carlo sampling using the EnergySampler object. If the replay buffer is not empty, the initial samples are drawn from the buffer.

Note that the batch size of the sampler is set to round(Int, n / 100) by default. This is to ensure that the samples are drawn independently from the posterior distribution. It also helps to avoid vanishing gradients.

The chain is run persistently until n samples are generated. The number of transitions is set to ceil(Int, n / batch_size). Once the chain has run, the last n samples from the replay buffer are returned.

Arguments

  • e::EnergySampler: The EnergySampler object to be used for sampling.
  • n::Int=1000: The number of samples to generate.
  • batch_size::Int=round(Int, n / 100): The batch size for sampling.
  • niter::Int=1000: The number of iterations for generating samples from the posterior distribution.
  • kwargs...: Additional keyword arguments to be passed on to the sampler.

Returns

  • AbstractArray: The generated samples.
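With the defaults stated above, the sampler settings for n = 1000 work out as follows:

```julia
# Defaults described above, evaluated for n = 1000 requested samples.
n = 1000
batch_size = round(Int, n / 100)           # 10 samples per transition
ntransitions = ceil(Int, n / batch_size)   # 100 transitions
```

As written, the batch-size formula evaluates to 0 for n ≤ 50 (Julia's round sends the tie 0.5 to the even integer 0), so very small n would need a lower bound such as max(1, ...).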
source
CounterfactualExplanations.Evaluation.get_lowest_energy_sample - Method
get_lowest_energy_sample(sampler::EnergySampler; n::Int=5)

Chooses the samples with the lowest energy (i.e. highest probability) from the EnergySampler.

Arguments

  • sampler::EnergySampler: The EnergySampler object to be used for sampling.
  • n::Int=5: The number of samples to choose.

Returns

  • AbstractArray: The samples with the lowest energy.
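Selecting the lowest-energy samples amounts to a partial sort on the energies. The sketch below assumes that the samples are stored as columns and that their energies have already been computed. The name lowest_energy_samples is hypothetical.

```julia
# Hypothetical sketch: return the n columns of `samples` with the lowest energies.
function lowest_energy_samples(samples::AbstractMatrix, energies::AbstractVector; n::Int=5)
    idx = partialsortperm(energies, 1:min(n, length(energies)))  # indices of n smallest
    return samples[:, idx]
end

samples = [1.0 2.0 3.0 4.0; 5.0 6.0 7.0 8.0]   # 2 features × 4 samples
energies = [0.3, -1.2, 0.8, -0.5]
lowest_energy_samples(samples, energies; n=2)  # columns 2 and 4
```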
source
CounterfactualExplanations.Evaluation.get_sampler! - Method
get_sampler!(ce::AbstractCounterfactualExplanation; kwargs...)

Gets the EnergySampler object from the counterfactual explanation. If the sampler is not found, it is constructed and stored in the counterfactual explanation object.

source
CounterfactualExplanations.Evaluation.to_dataframe - Method
to_dataframe(
    ce::CounterfactualExplanation,
    measure::Vector{Function},
    agg::Function,
    report_each::Bool,
    pivot_longer::Bool,
    store_ce::Bool,
)

Evaluates a counterfactual explanation and returns a dataframe of evaluation measures.

source
CounterfactualExplanations.Evaluation.warmup! - Method
warmup!(
    e::EnergySampler,
    y::Int;
    niter::Int=20,
    ntransitions::Int=100,
    kwargs...,
)

Warms up the EnergySampler for the underlying model and conditioning value y. Specifically, this entails running Persistent Monte Carlo (PMC) for niter iterations and ntransitions transitions to build a buffer of samples. The buffer is used for posterior sampling.

Arguments

  • e::EnergySampler: The EnergySampler object to be trained.
  • y::Int: The conditioning value.
  • opt::Union{Nothing,AbstractSamplingRule}: The sampling rule to be used. By default, ImproperSGLD is used with α = 2 * std(Uniform(𝒟x)) and γ = 0.005α.
  • niter::Int=20: The number of iterations for training the sampler through PMC.
  • ntransitions::Int=100: The number of transitions for training the sampler. In each transition, the sampler is updated with a mini-batch of data. Data is either drawn from the replay buffer or reinitialized from the prior.
  • kwargs...: Additional keyword arguments to be passed on to the sampler and PMC.

Returns

  • EnergySampler: The trained EnergySampler.
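The default step parameters can be made concrete by using the fact that a uniform distribution on [lb, ub] has standard deviation (ub - lb)/√12. The bounds below are hypothetical. For a prior on [-3, 3], this gives α = 2√3:

```julia
# Standard deviation of a uniform distribution on [lb, ub] is (ub - lb) / √12.
uniform_std(lb, ub) = (ub - lb) / sqrt(12)

lb, ub = -3.0, 3.0            # hypothetical prior bounds
α = 2 * uniform_std(lb, ub)   # step size, 2√3 ≈ 3.464
γ = 0.005α                    # noise scale, ≈ 0.0173
```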
source
CounterfactualExplanations.DataPreprocessing.InputTransformer - Type
InputTransformer

Abstract type for data transformers. This can be any of the following:

  • StatsBase.AbstractDataTransform: A data transformation object from the StatsBase package.
  • MultivariateStats.AbstractDimensionalityReduction: A dimensionality reduction object from the MultivariateStats package.
  • GenerativeModels.AbstractGenerativeModel: A generative model object from the GenerativeModels module.
source
CounterfactualExplanations.DataPreprocessing.convert_to_1d - Method
convert_to_1d(y::Matrix, y_levels::AbstractArray)

Helper function to convert a one-hot encoded matrix to a vector of labels. This is necessary because MLJ models require the labels to be represented as a vector, but the synthetic datasets in this package hold the labels in one-hot encoded form.

Arguments

  • y::Matrix: The one-hot encoded matrix.
  • y_levels::AbstractArray: The levels of the categorical variable.

Returns

  • labels: A vector of labels.
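The conversion amounts to taking, for each sample, the level at the position of the largest entry. This sketch assumes one row per class and one column per sample. The name onehot_to_labels is hypothetical, and a binary target stored as a single row would need separate handling.

```julia
# Hypothetical sketch: map each column of a one-hot matrix (classes × samples)
# to the level whose row holds the largest entry.
function onehot_to_labels(y::AbstractMatrix, y_levels::AbstractVector)
    return [y_levels[argmax(col)] for col in eachcol(y)]
end

y = [1 0 0; 0 1 1]                  # 2 classes × 3 samples
onehot_to_labels(y, ["no", "yes"])  # one label per sample
```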
source
CounterfactualExplanations.DataPreprocessing.preprocess_data_for_mlj - Method
preprocess_data_for_mlj(data::CounterfactualData)

Helper function to preprocess data::CounterfactualData for MLJ models.

Arguments

  • data::CounterfactualData: The data to be preprocessed.

Returns

  • (df_x, y): A tuple containing the preprocessed data, with df_x being a DataFrame object and y being a categorical vector.

Example

X, y = preprocess_data_for_mlj(data)

source
CounterfactualExplanations.DataPreprocessing.train_test_split - Method
train_test_split(data::CounterfactualData; test_size=0.2, keep_class_ratio=false)

Splits data into a training set and a test set.

Arguments

  • data::CounterfactualData: The data to be preprocessed.
  • test_size=0.2: Proportion of the data to be used for testing.
  • keep_class_ratio=false: Whether to preserve the relative class sizes in both splits (true) or to sample equally from each class (false).

Returns

  • (train_data::CounterfactualData, test_data::CounterfactualData): A tuple containing the train and test splits.

Example

train, test = train_test_split(data; test_size=0.1, keep_class_ratio=true)
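For keep_class_ratio=true, a split that preserves class proportions can be sketched by splitting each class separately. This illustration works on a plain label vector and uses a hypothetical helper, stratified_split. It is not the package's implementation.

```julia
using Random

# Hypothetical sketch: shuffle and split each class separately so that both
# splits keep (approximately) the original class proportions.
function stratified_split(rng::AbstractRNG, labels::AbstractVector; test_size=0.2)
    train_idx, test_idx = Int[], Int[]
    for cls in unique(labels)
        idx = shuffle(rng, findall(==(cls), labels))
        n_test = round(Int, test_size * length(idx))
        append!(test_idx, idx[1:n_test])
        append!(train_idx, idx[n_test+1:end])
    end
    return train_idx, test_idx
end

labels = vcat(fill(:a, 80), fill(:b, 20))   # 80/20 class imbalance
train_idx, test_idx = stratified_split(MersenneTwister(0), labels; test_size=0.1)
```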

source
CounterfactualExplanations.Models.Fitresult - Type
Fitresult

A struct to hold the results of fitting a model.

Fields

  • fitresult: The result of fitting the model to the data. This object should be callable on new data.
  • other::Dict: A dictionary to hold any other relevant information.
source
CounterfactualExplanations.GenerativeModels.VAE - Type
VAE <: AbstractGenerativeModel

Constructs the Variational Autoencoder. The VAE is a subtype of AbstractGenerativeModel. Any (sub-)type of AbstractGenerativeModel is accepted by latent space generators.

source
Base.rand - Function

Random.rand(encoder::Encoder, x, device=cpu)

Draws random samples from the latent distribution.
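Sampling from a Gaussian latent distribution is commonly done with the reparameterisation trick. The sketch below assumes that the encoder outputs a mean μ and a log-variance logvar. The function sample_latent is hypothetical and stands in for the method above.

```julia
using Random

# Reparameterisation trick: z = μ + σ .* ε with ε ~ N(0, I), where σ = exp(logvar / 2).
# Assumes the encoder returns the mean μ and log-variance logvar (hypothetical).
function sample_latent(rng::AbstractRNG, μ::AbstractArray, logvar::AbstractArray)
    ε = randn(rng, size(μ)...)
    return μ .+ exp.(logvar ./ 2) .* ε
end

z = sample_latent(MersenneTwister(0), zeros(2, 5), zeros(2, 5))   # 2 latent dims × 5 draws
```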

source
CounterfactualExplanations.Convergence.conditions_satisfied - Method
Convergence.conditions_satisfied(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to check whether all conditions for convergence of the counterfactual search have been satisfied for gradient-based generators. By default, gradient-based search is considered to have converged as soon as the proposed changes for all features are smaller than one percent of their respective standard deviations.
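The default rule can be sketched as an elementwise comparison. Here Δ holds the proposed change per feature and σ holds the per-feature standard deviations. Both names and the helper changes_negligible are hypothetical.

```julia
# Hypothetical sketch of the default rule: every proposed change Δ[i] must be
# smaller than tol (1%) of the corresponding standard deviation σ[i].
changes_negligible(Δ::AbstractVector, σ::AbstractVector; tol=0.01) =
    all(abs.(Δ) .< tol .* σ)

σ = [2.0, 0.5]                          # hypothetical per-feature std devs
changes_negligible([0.01, -0.004], σ)   # 0.01 < 0.02 and 0.004 < 0.005
changes_negligible([0.03, 0.0], σ)      # 0.03 is not below 0.02
```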

source
CounterfactualExplanations.Generators.feature_selection! - Method
feature_selection!(ce::AbstractCounterfactualExplanation)

Perform feature selection to find the dimension with the closest (but not equal) values between the ce.factual (factual) and ce.counterfactual_state (counterfactual) arrays.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

The function iteratively modifies the ce.counterfactual_state counterfactual array by updating its elements to match the corresponding elements in the ce.factual factual array, one dimension at a time, until the predicted label of the modified ce.counterfactual_state matches the predicted label of the ce.factual array.

source
CounterfactualExplanations.Generators.find_closest_dimension - Method
find_closest_dimension(factual, counterfactual)

Find the dimension with the closest (but not equal) values between the factual and counterfactual arrays.

Arguments

  • factual: The factual array.
  • counterfactual: The counterfactual array.

Returns

  • closest_dimension: The index of the dimension with the closest values.

The function iterates over the indices of the factual array and calculates the absolute difference between the corresponding elements in the factual and counterfactual arrays. It returns the index of the dimension with the smallest difference, excluding dimensions where the values in factual and counterfactual are equal.
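This is a direct translation of the description above. The name closest_dimension is hypothetical. Returning nothing when all dimensions are equal is an assumption, since the behaviour in that case is not documented here.

```julia
# Hypothetical sketch: index of the dimension with the smallest non-zero absolute
# difference between factual and counterfactual; `nothing` if all are equal.
function closest_dimension(factual::AbstractVector, counterfactual::AbstractVector)
    best, best_diff = nothing, Inf
    for i in eachindex(factual, counterfactual)
        d = abs(factual[i] - counterfactual[i])
        if d > 0 && d < best_diff   # skip dimensions where the values are equal
            best, best_diff = i, d
        end
    end
    return best
end
```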

source
CounterfactualExplanations.Generators.find_counterfactual - Method
find_counterfactual(model, factual_class, counterfactual_data, counterfactual_candidates)

Finds the index of the first counterfactual among the candidates by predicting their labels.

Arguments

  • model: The fitted model used for prediction.
  • target_class: Expected target class.
  • counterfactual_data: Data required for counterfactual generation.
  • counterfactual_candidates: The array of counterfactual candidates.

Returns

  • counterfactual: The index of the first counterfactual found.
source
CounterfactualExplanations.Generators.growing_spheres_generation! - Method
growing_spheres_generation!(ce::AbstractCounterfactualExplanation)

Generate counterfactual candidates using the growing spheres generation algorithm.

Arguments

  • ce::AbstractCounterfactualExplanation: An instance of the AbstractCounterfactualExplanation type representing the counterfactual explanation.

Returns

  • nothing

This function applies the growing spheres generation algorithm to generate counterfactual candidates. It starts by generating random points uniformly on a sphere, gradually reducing the search space until no counterfactuals are found. Then it expands the search space until at least one counterfactual is found or the maximum number of iterations is reached.

The algorithm iteratively generates counterfactual candidates and predicts their labels using the model stored in ce.M. It checks if any of the predicted labels are different from the factual class. The process of reducing the search space involves halving the search radius, while the process of expanding the search space involves increasing the search radius.
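The following simplified sketch runs the two phases on a toy linear classifier. It makes three assumptions. First, candidates lie on the sphere surface, whereas the method samples spherical layers (see hyper_sphere_coordinates). Second, the expansion step is a fixed tenth of the radius reached in the first phase. Third, growing_spheres and sphere_points are hypothetical names.

```julia
using LinearAlgebra, Random

# n points uniformly on the sphere of radius r around x (normalised Gaussian directions).
sphere_points(rng::AbstractRNG, x::AbstractVector, r::Real, n::Int) =
    [x .+ r .* normalize(randn(rng, length(x))) for _ in 1:n]

# Hypothetical, simplified growing-spheres search for a classifier `predict`.
function growing_spheres(predict, x::AbstractVector; r0=1.0, n=200, max_iter=100,
                         rng::AbstractRNG=MersenneTwister(42))
    y0 = predict(x)
    is_cf(p) = predict(p) != y0
    r = r0
    # Phase 1: halve the radius while counterfactuals are still found.
    for _ in 1:max_iter
        any(is_cf, sphere_points(rng, x, r, n)) || break
        r /= 2
    end
    # Phase 2: grow the radius in fixed steps until a counterfactual appears.
    step = r / 10
    for _ in 1:max_iter
        candidates = sphere_points(rng, x, r, n)
        i = findfirst(is_cf, candidates)
        i === nothing || return candidates[i]
        r += step
    end
    return nothing   # no counterfactual within max_iter expansions
end

toy_predict(p) = p[1] + p[2] > 1 ? 1 : 0   # toy linear classifier
cf = growing_spheres(toy_predict, [0.0, 0.0])
```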

source
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Function, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided.

source
CounterfactualExplanations.Generators.h - Method
h(generator::AbstractGenerator, penalty::Tuple, ce::AbstractCounterfactualExplanation)

Overloads the h function for the case where a single penalty function is provided with additional keyword arguments.
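The h overloads rely on Julia's multiple dispatch on the penalty type. The hypothetical sketch below drops the generator and explanation arguments. It also assumes the Tuple form is (penalty_function, keyword_arguments).

```julia
# Hypothetical sketch of dispatching on the penalty type.
penalty_value(penalty::Function, x) = penalty(x)
# Assumed Tuple form: (penalty_function, NamedTuple of keyword arguments).
penalty_value(penalty::Tuple, x) = penalty[1](x; penalty[2]...)

l2(x; scale=1.0) = scale * sum(abs2, x)   # toy penalty function

x = [1.0, 2.0]
penalty_value(l2, x)                    # plain function
penalty_value((l2, (; scale=0.5)), x)   # function with keyword arguments
```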

source
CounterfactualExplanations.Generators.hyper_sphere_coordinates - Method
hyper_sphere_coordinates(n_search_samples::Int, instance::AbstractArray, low::AbstractFloat, high::AbstractFloat; p_norm::Int=2)

Generates candidate counterfactuals using the growing spheres method based on hyper-sphere coordinates.

The implementation follows the Random Point Picking over a sphere algorithm described in the paper: "Learning Counterfactual Explanations for Tabular Data" by Pawelczyk, Broelemann & Kasneci (2020), presented at The Web Conference 2020 (WWW). It ensures that points are sampled uniformly at random using insights from: http://mathworld.wolfram.com/HyperspherePointPicking.html

The growing spheres method was originally proposed in the paper: "Comparison-based Inverse Classification for Interpretability in Machine Learning" by Thibaut Laugel et al. (2018), presented at the International Conference on Information Processing and Management of Uncertainty in Knowledge-Based Systems (IPMU 2018).

Arguments

  • n_search_samples::Int: The number of search samples (int > 0).
  • instance::AbstractArray: The input point array.
  • low::AbstractFloat: The lower bound (float >= 0, l < h).
  • high::AbstractFloat: The upper bound (float >= 0, h > l).
  • p_norm::Integer: The norm parameter (int >= 1).

Returns

  • candidate_counterfactuals::Array: An array of candidate counterfactuals.
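One way to implement the point picking described above is as follows. Draw Gaussian directions, rescale each to unit p-norm, and multiply it by a radius drawn uniformly from [low, high). This sketch uses a hypothetical name. A uniformly drawn radius does not give a uniform density over the volume of the layer.

```julia
using LinearAlgebra, Random

# Hypothetical sketch: n_search_samples candidates (as columns) whose p-norm
# distance from `instance` lies in [low, high).
function hyper_sphere_candidates(rng::AbstractRNG, n_search_samples::Int,
                                 instance::AbstractVector, low::Real, high::Real;
                                 p_norm::Real=2)
    δ = randn(rng, length(instance), n_search_samples)          # random directions
    radii = low .+ (high - low) .* rand(rng, n_search_samples)  # radii in [low, high)
    for j in 1:n_search_samples
        col = view(δ, :, j)
        col .*= radii[j] / norm(col, p_norm)                    # rescale to chosen radius
    end
    return instance .+ δ
end

C = hyper_sphere_candidates(MersenneTwister(0), 100, [1.0, 2.0], 0.5, 1.0)
```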
source
CounterfactualExplanations.Generators.incompatible - Method
incompatible(::AbstractGenerator, ::AbstractCounterfactualExplanation)

Checks if the generator is incompatible with any of the additional specifications for the counterfactual explanations. By default, generators are assumed to be compatible.

source
CounterfactualExplanations.Generators.propose_state - Method
propose_state(
    ::Models.IsDifferentiable,
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

Proposes a new state based on backpropagation for gradient-based generators and differentiable models.

source
CounterfactualExplanations.Generators.∂h - Method
∂h(generator::AbstractGradientBasedGenerator, ce::AbstractCounterfactualExplanation)

The default method to compute the gradient of the complexity penalty at the current counterfactual state for gradient-based generators. It assumes that gradients can be obtained through Zygote.jl.

If no penalty is provided, it returns 0.0. Zygote does not compute a gradient for constants and returns 'nothing' instead, so a manual step is added to override this behaviour. See here: https://discourse.julialang.org/t/zygote-gradient/26715.

source
CounterfactualExplanations.Generators.∂ℓ - Method
∂ℓ(
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

The default method to compute the gradient of the loss function at the current counterfactual state for gradient-based generators. It assumes that gradients can be obtained through Zygote.jl.

source
CounterfactualExplanations.Generators.∇ - Method
∇(
    generator::AbstractGradientBasedGenerator,
    ce::AbstractCounterfactualExplanation,
)

The default method to compute the gradient of the counterfactual search objective for gradient-based generators. It computes the weighted sum of the partial derivatives. It assumes that gradients can be obtained through Zygote.jl. If the counterfactual is being generated using Probe, the hinge loss is added to the gradient.
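To make the weighted sum concrete, the sketch below uses a hypothetical quadratic objective with hand-written gradients in place of Zygote.jl. It applies a plain gradient-descent proposal x′ = x - η∇ and omits the Probe hinge loss. With λ = 0.5, the iterates converge to (t + λx0)/(1 + λ) = 2/3 per coordinate:

```julia
# Hand-written gradients standing in for Zygote.jl (hypothetical objective):
# loss ℓ(x) = sum(abs2, x .- t) pulls x towards a target point t,
# penalty h(x) = sum(abs2, x .- x0) pulls it back towards the factual x0.
grad_loss(x, t) = 2 .* (x .- t)
grad_penalty(x, x0) = 2 .* (x .- x0)

# Gradient of the objective: weighted sum of the partial derivatives.
grad_objective(x, t, x0, λ) = grad_loss(x, t) .+ λ .* grad_penalty(x, x0)

# One gradient-descent proposal: x′ = x - η ∇.
propose(x, t, x0, λ; η=0.1) = x .- η .* grad_objective(x, t, x0, λ)

function descend(x, t, x0, λ; η=0.1, steps=200)
    for _ in 1:steps
        x = propose(x, t, x0, λ; η=η)
    end
    return x
end

x0, t, λ = [0.0, 0.0], [1.0, 1.0], 0.5
propose(x0, t, x0, λ)   # first step: [0.2, 0.2]
descend(x0, t, x0, λ)   # ≈ [2/3, 2/3]
```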

source
CounterfactualExplanations.Objectives.distance_from_target - Method
distance_from_target(
    ce::AbstractCounterfactualExplanation;
    K::Int=50
)

Computes the distance of the counterfactual from samples in the target manifold. If choose_randomly is true, the function will randomly sample K neighbours from the target manifold. Otherwise, it will compute the pairwise distances and select the K closest neighbours.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation.
  • K::Int=50: The number of neighbours to sample.
  • choose_randomly::Bool=true: Whether to sample neighbours randomly.
  • kwargs...: Additional keyword arguments for the distance function.

Returns

  • Δ::AbstractFloat: The distance from the counterfactual to the target manifold.
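The following sketch shows both neighbour-selection modes. It assumes Euclidean distances, target-class samples stored as the columns of X_target, and the mean as the aggregation. The name distance_to_target is hypothetical.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical sketch: mean Euclidean distance from x_cf to K target samples,
# chosen either at random or as the K nearest neighbours.
function distance_to_target(x_cf::AbstractVector, X_target::AbstractMatrix; K::Int=50,
                            choose_randomly::Bool=true,
                            rng::AbstractRNG=Random.default_rng())
    dists = [norm(x_cf .- c) for c in eachcol(X_target)]
    K = min(K, length(dists))
    idx = choose_randomly ? randperm(rng, length(dists))[1:K] : partialsortperm(dists, 1:K)
    return mean(dists[idx])
end

X_target = [0.0 3.0 10.0; 0.0 4.0 0.0]   # distances from the origin: 0, 5, 10
distance_to_target([0.0, 0.0], X_target; K=2, choose_randomly=false)
```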
source
CounterfactualExplanations.Objectives.energy - Method
energy(M::AbstractModel, x::AbstractArray, t::Int)

Computes the energy of the model at a given state as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5.
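The documentation for energy_constraint states that the energy is the negative logit of the target class. Assuming that definition, the energy can be sketched with a toy linear model standing in for M. The names energy_sketch and toy_logits are hypothetical.

```julia
# Hypothetical sketch: energy of input x for target class t as the negative
# logit of that class, E(x, t) = -f(x)[t].
energy_sketch(logits_fn, x, t::Int) = -logits_fn(x)[t]

W = [1.0 -1.0; -1.0 1.0]   # toy linear model: logits = W * x
toy_logits(x) = W * x

energy_sketch(toy_logits, [2.0, 0.5], 1)   # logits are [1.5, -1.5]
```

Lower energy corresponds to a higher logit, i.e. a more likely class under the model.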

source
CounterfactualExplanations.Objectives.energy_constraint - Method
energy_constraint(
    ce::AbstractCounterfactualExplanation;
    agg=mean,
    reg_strength::AbstractFloat=0.0,
    decay::AbstractFloat=0.9,
    kwargs...,
)

Computes the energy constraint for the counterfactual explanation as in Altmeyer et al. (2024): https://scholar.google.com/scholar?cluster=3697701546144846732&hl=en&as_sdt=0,5. The energy constraint is a regularization term that penalizes the energy of the counterfactuals. The energy is computed as the negative logit of the target class.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation.
  • agg::Function=mean: The aggregation function (only applicable when num_counterfactuals > 1).
  • reg_strength::AbstractFloat=0.0: The regularization strength.
  • decay::AbstractFloat=0.9: The decay rate for the polynomial decay function. Parameter a is set to 1.0 / ce.generator.opt.eta, such that the initial step size is equal to 1.0, not accounting for b. Parameter b is set to round(Int, max_steps / 20), where max_steps is the maximum number of iterations.
  • kwargs...: Additional keyword arguments.

Returns

  • ℒ::AbstractFloat: The energy constraint.
source

Extensions

CounterfactualExplanations.Models.Model - Method
(M::Models.Model)(
    data::CounterfactualData,
    type::CounterfactualExplanations.DecisionTreeModel;
    kwargs...,
)

Constructs a decision tree for the given data. This method is used internally when a decision-tree model is constructed to be trained from scratch (i.e. no pre-trained model is supplied by the user).

source
DecisionTreeExt.calculate_delta - Method
calculate_delta(ce::AbstractCounterfactualExplanation, penalty::Vector{Function})

Calculates the penalty for the proposed feature tweak.

Arguments

  • ce::AbstractCounterfactualExplanation: The counterfactual explanation object.

Returns

  • delta::Float64: The calculated penalty for the proposed feature tweak.
source
DecisionTreeExt.classify_prototypes - Method
classify_prototypes(prototypes, rule_assignments, bounds)

Builds the second tree model using the given prototypes as inputs and their corresponding rule_assignments as labels. Split thresholds are restricted to the bounds, which can be computed using partition_bounds(rules). For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.cre - Method
cre(rules, x, X)

Computes the counterfactual rule explanations (CRE) for a given point $x$ and a set of $rules$, where the $rules$ correspond to the set of maximal-valid rules for some given target. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.esatisfactory_instance - Method
esatisfactory_instance(generator::FeatureTweakGenerator, x::AbstractArray, paths::Dict{String, Dict{String, Any}})

Returns an epsilon-satisfactory counterfactual for x based on the paths provided.

Arguments

  • generator::FeatureTweakGenerator: The feature tweak generator.
  • x::AbstractArray: The factual instance.
  • paths::Dict{String, Dict{String, Any}}: The paths to the leaves of the tree to be used for tweaking the features.

Returns

  • esatisfactory::AbstractArray: The epsilon-satisfactory instance.

Example

esatisfactory = esatisfactory_instance(generator, x, paths) # returns an epsilon-satisfactory counterfactual for x based on the paths provided
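In feature tweaking, an epsilon-satisfactory instance is typically obtained by moving each feature that violates a split condition on the path to just past the split threshold. The sketch below uses a hypothetical tuple representation (feature, is_left, threshold) instead of the nested Dict used above. Here is_left means the path requires x[feature] < threshold. The boundary convention and the name esatisfactory_sketch are assumptions.

```julia
# Hypothetical sketch: move every feature that violates a condition on the path
# to just past the threshold, by a margin epsilon.
function esatisfactory_sketch(x::AbstractVector, path; epsilon=0.1)
    x_cf = copy(x)
    for (feature, is_left, threshold) in path
        if is_left && x_cf[feature] >= threshold
            x_cf[feature] = threshold - epsilon   # needs x[feature] < threshold
        elseif !is_left && x_cf[feature] < threshold
            x_cf[feature] = threshold + epsilon   # needs x[feature] >= threshold
        end
    end
    return x_cf
end

path = [(1, true, 0.5), (2, false, 2.0)]   # x[1] < 0.5 and x[2] >= 2.0
esatisfactory_sketch([1.0, 3.0], path)     # only the first feature is tweaked
```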

source
DecisionTreeExt.extract_leaf_rules - Method
extract_leaf_rules(root::DT.Root)

Extracts leaf decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $L$ hyperrectangles. The rules are returned as a vector of tuples containing 2-element tuples, where each 2-element tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.extract_rules - Method
extract_rules(root::DT.Root)

Extracts decision rules (i.e. hyperrectangles) from a decision tree (root). For a decision tree with $L$ leaves this results in $2L-1$ hyperrectangles. The rules are returned as a vector of vectors of 2-element tuples, where each tuple stores the lower and upper bound imposed by the given rule for a given feature. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.get_individual_classifiers - Method
get_individual_classifiers(M::Model)

Returns the individual classifiers in the forest. If the input is a decision tree, the method returns the decision tree itself inside an array.

Arguments

  • M::Model: The model selected by the user.
  • model::CounterfactualExplanations.D

Returns

  • classifiers::AbstractArray: An array of individual classifiers in the forest.
source
DecisionTreeExt.issubrule - Method
issubrule(rule, otherrule)

Checks if the rule hyperrectangle is a subset of the otherrule hyperrectangle. For details see Bewley et al. (2024) [arXiv, PMLR].
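Suppose a rule is represented, as described for extract_rules, by a vector of (lower, upper) bounds with one entry per feature. The subset check then reduces to interval containment per feature. The following sketch uses a hypothetical name:

```julia
# Hypothetical sketch: `rule` is a sub-rule of `otherrule` if each of its
# (lower, upper) intervals lies inside the corresponding interval of `otherrule`.
issubrule_sketch(rule, otherrule) =
    all(lo2 <= lo1 && hi1 <= hi2 for ((lo1, hi1), (lo2, hi2)) in zip(rule, otherrule))

r1 = [(0.0, 1.0), (-Inf, 2.0)]
r2 = [(-Inf, 1.5), (-Inf, Inf)]
issubrule_sketch(r1, r2)   # r1 lies inside r2
```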

source
DecisionTreeExt.max_valid - Method
max_valid(rules, X, fx, target, τ)

Returns the maximal-valid rules for a given target and accuracy threshold τ. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.prototype - Method
prototype(rule, X; pick_arbitrary::Bool=true)

Picks an arbitrary point $x^C \in X$ (i.e. prototype) from the subset of $X$ that is contained by rule $R_i$. If pick_arbitrary is set to false, the prototype is instead computed as the average across all samples. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.rule_accuracy - Method
rule_accuracy(rule, X, fx, target)

Computes the accuracy of the rule on the data X for predicted outputs fx and the target. Accuracy is defined as the fraction of points contained by the rule for which the predicted value matches the target. For details see Bewley et al. (2024) [arXiv, PMLR].
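The following sketch shows the accuracy computation. It assumes that samples are the columns of X and that a point is contained by a rule if every feature lies within the closed interval [lower, upper]. This boundary convention is an assumption, and rule_accuracy_sketch is a hypothetical name.

```julia
# Assumed containment test: every feature of x lies within [lower, upper].
contains_point(rule, x) = all(lo <= xi <= hi for ((lo, hi), xi) in zip(rule, x))

# Hypothetical sketch: fraction of contained points whose prediction equals target.
function rule_accuracy_sketch(rule, X::AbstractMatrix, fx::AbstractVector, target)
    inside = [contains_point(rule, x) for x in eachcol(X)]
    n_inside = count(inside)
    n_inside == 0 && return 0.0        # assumption: empty rules get accuracy 0
    return count(fx[inside] .== target) / n_inside
end

X = [0.1 0.4 0.9 0.2; 0.5 0.8 0.3 0.9]   # 2 features × 4 samples
fx = [1, 0, 1, 1]                        # predicted labels
rule = [(0.0, 0.5), (0.0, 1.0)]          # first feature in [0, 0.5]
rule_accuracy_sketch(rule, X, fx, 1)     # 2 of the 3 contained points match
```

The same containment test gives the feasibility described for rule_feasibility, count(inside) / size(X, 2), which is 3/4 for this rule.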

source
DecisionTreeExt.rule_cost - Method
rule_cost(rule, x, X)

Computes the cost for $x$ to be contained by rule $R_i$, where cost is defined as rule_changes(rule, x) - rule_feasibility(rule, X). For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.rule_feasibility - Method
rule_feasibility(rule, X)

Computes the feasibility of a rule $R_i$ for a given dataset. Feasibility is defined as the fraction of data points that satisfy the rule. For details see Bewley et al. (2024) [arXiv, PMLR].

source
DecisionTreeExt.search_path - Function
search_path(tree::Union{DT.Leaf, DT.Node}, target::RawTargetType, path::AbstractArray)

Returns a path index list with the inequality symbols, thresholds and feature indices.

Arguments

  • tree::Union{DT.Leaf, DT.Node}: The root node of a decision tree.
  • target::RawTargetType: The target class.
  • path::AbstractArray: A list containing the paths found thus far.

Returns

  • paths::AbstractArray: A list of paths to the leaves of the tree to be used for tweaking the feature.

Example

paths = search_path(tree, target) # returns a list of paths to the leaves of the tree to be used for tweaking the feature

source
CounterfactualExplanations.JEM - Method
CounterfactualExplanations.JEM(
    model::JointEnergyModels.JointEnergyClassifier; likelihood::Symbol=:classification_multi
)

Outer constructor for a joint energy model (JEM) from JointEnergyModels.jl.

source
CounterfactualExplanations.Models.logits - Method
Models.logits(M::JEM, X::AbstractArray)

Calculates the logit scores output by the model M for the input data X.

Arguments

  • M::JEM: The model selected by the user. Must be a model from the MLJ library.
  • X::AbstractArray: The feature vector for which the logit scores are calculated.

Returns

  • logits::Matrix: A matrix of logits for each output class for each data point in X.

Example

logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x

source
CounterfactualExplanations.Models.train - Method
train(M::JEM, data::CounterfactualData; kwargs...)

Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.

Arguments

  • M::JEM: The wrapper for a JEM model.
  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • M::JEM: The fitted JEM model.
source
CounterfactualExplanations.LaplaceReduxModel - Method
CounterfactualExplanations.LaplaceReduxModel(
    model::LaplaceRedux.Laplace; likelihood::Symbol=:classification_binary
)

Outer constructor for a neural network with Laplace Approximation from LaplaceRedux.jl.

source
CounterfactualExplanations.Models.train - Method
train(M::LaplaceReduxModel, data::CounterfactualData; kwargs...)

Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.

Arguments

  • M::LaplaceReduxModel: The wrapper for a LaplaceRedux model.
  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • M::LaplaceReduxModel: The fitted LaplaceReduxModel.
source
CounterfactualExplanations.NeuroTreeModel - Method
CounterfactualExplanations.NeuroTreeModel(
    model::AtomicNeuroTree; likelihood::Symbol=:classification_binary
)

Outer constructor for a differentiable tree-based model from NeuroTreeModels.jl.

source
CounterfactualExplanations.Models.logits - Method
Models.logits(M::NeuroTreeModel, X::AbstractArray)

Calculates the logit scores output by the model M for the input data X.

Arguments

  • M::NeuroTreeModel: The model selected by the user. Must be a model from the MLJ library.
  • X::AbstractArray: The feature vector for which the logit scores are calculated.

Returns

  • logits::Matrix: A matrix of logits for each output class for each data point in X.

Example

logits = Models.logits(M, x) # calculates the logit scores for each output class for the data point x

source
CounterfactualExplanations.Models.train - Method
train(M::NeuroTreeModel, data::CounterfactualData; kwargs...)

Fits the model M to the data in the CounterfactualData object. This method is not called by the user directly.

Arguments

  • M::NeuroTreeModel: The wrapper for a NeuroTree model.
  • data::CounterfactualData: The CounterfactualData object containing the data to be used for training the model.

Returns

  • M::NeuroTreeModel: The fitted NeuroTree model.
source