Add the ReX algorithm to Captum #1642

@stav-af

Description

🚀 Feature

ReX is a perturbation-based approach for interpreting the predictions of a model on a given input. It returns an attribution tensor mapping inputs to their degree of responsibility [2] for the model's output. The approach is grounded in the theory of Actual Causality [4].

Given a prediction function, an input, and a baseline, the algorithm partitions the input and masks combinations of those partitions. When masking a partition, or a combination of partitions, alters the prediction, those partitions are recursively re-partitioned and masked in a search for minimal responsible partitions.

Requires:

  • forward_func: Callable -> Scalar
  • inputs: TensorOrTupleOfTensorsGeneric
  • baselines: BaselineType

Pseudocode:

responsibility(mut, valid_mutants):
  # a witness for mut is a valid mutant that stops preserving the prediction
  # once mut is masked alongside it
  witnesses = filter(valid_mutants, lambda vmut: (mut + vmut) not in valid_mutants)
  k = len(min(witnesses, key=len))  # k = 0 when masking mut alone changes the prediction
  return 1 / (1 + k)

ReX(input, baseline) -> responsibility_map:
  responsibility_map = zeros_like(input)
  original_prediction = forward_func(input)
  do N times:
    Q.push(input, 1)
    while Q:
      # parent_responsibility serves as the pop priority, so high-responsibility
      # regions are searched first
      parent_partition, parent_responsibility = Q.pop()
      mutants = partition_and_mask(parent_partition, responsibility_map)
      valid_mutants = filter(mutants, lambda mut: forward_func(mut) == original_prediction)
      for mut in mutants:
        mut_responsibility = responsibility(mut, valid_mutants)
        for idx in mut:
          responsibility_map[idx] = mut_responsibility / len(mut)

        if mut_responsibility > 0 and len(mut) > 1:
          Q.push(mut, mut_responsibility)

  return responsibility_map

Where N denotes the number of refinement passes.
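
Below is a minimal, runnable sketch of this search in plain PyTorch, not a Captum-ready implementation: it uses contiguous partitioning, bounds witness sets at size 1, and the helper names (_split, _masked_prediction, rex_attribute) are illustrative rather than taken from the draft PR.

    from typing import Callable, List

    import torch

    def _split(indices: torch.Tensor, n_partitions: int) -> List[torch.Tensor]:
        # Contiguous partitioning over flattened indices; the heuristic strategy
        # would instead weight splits by accumulated responsibility.
        return [p for p in torch.chunk(indices, n_partitions) if p.numel() > 0]

    def _masked_prediction(flat, shape, masked_idx, baseline, forward_func):
        # Replace the given indices with the baseline and re-run the model.
        mutant = flat.clone()
        mutant[masked_idx] = baseline
        return forward_func(mutant.view(shape))

    def rex_attribute(
        forward_func: Callable[[torch.Tensor], torch.Tensor],
        inputs: torch.Tensor,
        baseline: float = 0.0,
        search_depth: int = 10,
        n_partitions: int = 4,
        n_searches: int = 5,
    ) -> torch.Tensor:
        flat = inputs.flatten()
        responsibility_map = torch.zeros(flat.numel())
        original = forward_func(inputs)

        for _ in range(n_searches):
            queue = [(torch.arange(flat.numel()), 0)]  # (partition indices, depth)
            while queue:
                parent, depth = queue.pop()
                parts = _split(parent, n_partitions)
                for i, part in enumerate(parts):
                    if _masked_prediction(flat, inputs.shape, part, baseline, forward_func) != original:
                        r = 1.0  # k = 0: masking this partition alone flips the prediction
                    else:
                        # Look for a size-1 witness among the sibling partitions.
                        flips = any(
                            _masked_prediction(
                                flat, inputs.shape, torch.cat([part, other]), baseline, forward_func
                            ) != original
                            for j, other in enumerate(parts) if j != i
                        )
                        r = 0.5 if flips else 0.0  # 1 / (1 + k) with k = 1; 0 if no witness found
                    responsibility_map[part] = r / part.numel()
                    if r > 0 and part.numel() > 1 and depth + 1 < search_depth:
                        queue.append((part, depth + 1))
        return responsibility_map.view(inputs.shape)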

This algorithm is based on the ReX papers [1], with the caveat that this implementation optionally drops the assumption that responsibility is local (an assumption that holds for image classifiers) and does not assume contiguous partitioning. The full implementation is provided in the following draft pull request: https://github.com/stav-af/captum-sandbox/pulls

Design Considerations

  • ReX is perturbation-based and model-agnostic. Hence it should extend the PerturbationAttribution class and be instantiated with only a forward_func (see the sketch after this list).
  • Unlike other interpretability methods, ReX is based on a heuristic search, so it demands different keyword arguments from other classes. n_searches, n_partitions, and search_depth denote the number of refinement passes, the branching factor, and the maximum search depth, respectively. Additionally, an assume_locality parameter gives control over the partitioning strategy.
  • Like other attribution methods, ReX returns an attribution map.
  • ReX has complexity O(min{|input|, n_searches * n_partitions ^ search_depth}).
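
As a sketch of the first point above, the class could slot into Captum's hierarchy roughly as follows; the import path reflects current Captum source and should be treated as an assumption:

    from typing import Callable

    from captum.attr._utils.attribution import PerturbationAttribution

    class ReX(PerturbationAttribution):
        # Model-agnostic: the only state ReX needs is a scalar-valued forward_func.
        def __init__(self, forward_func: Callable) -> None:
            super().__init__(forward_func)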

Proposed API

ReX takes the same required arguments as other PerturbationAttribution methods, but its keyword arguments differ due to its recursive nature. Its constructor is identical to that of other methods:

ReX(forward_func: Callable)

Where forward_func must return a scalar value.
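
For example, a multi-class model can be reduced to a scalar prediction by returning its predicted label; this wrapper is an illustration, not part of the proposed API:

    from typing import Callable

    import torch

    def scalar_forward(model: torch.nn.Module) -> Callable[[torch.Tensor], torch.Tensor]:
        def forward_func(x: torch.Tensor) -> torch.Tensor:
            # Reduce the model's output to a single comparable value: the predicted class.
            return model(x.unsqueeze(0)).argmax(dim=-1).squeeze()
        return forward_func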

Its attribute method, however, takes different keyword arguments:

    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = 0,
        search_depth: int = 10,
        n_partitions: int = 4,
        n_searches: int = 5,
        assume_locality: bool = False,
    ) -> TensorOrTupleOfTensorsGeneric:

Where:

  • search_depth: The maximum depth of the recursive search.
  • n_partitions: The maximum number of sub-partitions created from each partition at each search step. This is also the branching factor.
  • n_searches: The number of times to repeat the search, using previously computed information to direct each subsequent search.
  • assume_locality: Switches between contiguous and heuristic partitioning. When True, the partitioning strategy round-robins over the dimensions of the input tensor, randomly dropping split points, and the returned map is the average over all searches. When False, partitioning is directed by the results of previous searches: indices of the input tensor are assigned to partitions so that each partition carries roughly equal responsibility, which means areas of high responsibility are partitioned more finely. In this strategy, the map from the final search is returned.
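
Putting the pieces together, usage would look roughly like this; the model, input shape, and the scalar_forward helper sketched earlier are placeholders:

    import torch

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 10))
    rex = ReX(scalar_forward(model))

    attributions = rex.attribute(
        inputs=torch.rand(3, 8, 8),
        baselines=0,
        search_depth=10,
        n_partitions=4,
        n_searches=5,
        assume_locality=True,  # contiguous partitioning, map averaged over searches
    )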

References & Links:

  1. The ReX papers (called DC-Causal and DeepCover therein): https://www.hanachockler.com/eccv/ and https://www.hanachockler.com/iccv2021/
  2. Responsibility and Blame: A Structural-Model Approach: https://arxiv.org/abs/cs/0312038
  3. ReX-XAI, a comprehensive ReX implementation for image classifiers: https://github.com/ReX-XAI/ReX
  4. The Actual Causality book: https://www.cs.cornell.edu/home/halpern/papers/causalitybook-ch1-3.html
