From d53be07c6f863959caa03f19c4e76a406ff7b49b Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Wed, 18 Oct 2023 13:03:24 +0200 Subject: [PATCH] Add Sequential IG method (#222) --- README.md | 25 +- .../main_classes/feature_attribution.rst | 7 +- inseq/attr/feat/__init__.py | 2 + inseq/attr/feat/gradient_attribution.py | 18 +- inseq/attr/feat/internals_attribution.py | 2 - inseq/attr/feat/ops/__init__.py | 2 + .../ops/discretized_integrated_gradients.py | 2 - inseq/attr/feat/ops/lime.py | 1 - .../ops/sequential_integrated_gradients.py | 465 ++++++++++++++++++ inseq/models/model_config.py | 2 + 10 files changed, 506 insertions(+), 20 deletions(-) create mode 100644 inseq/attr/feat/ops/sequential_integrated_gradients.py diff --git a/README.md b/README.md index 302cc449..92c8a5e8 100644 --- a/README.md +++ b/README.md @@ -130,27 +130,29 @@ Use the `inseq.list_feature_attribution_methods` function to list all available #### Gradient-based attribution -- `saliency`: [Saliency](https://arxiv.org/abs/1312.6034) (Simonyan et al., 2013) +- `saliency`: [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/abs/1312.6034) (Simonyan et al., 2013) -- `input_x_gradient`: [Input x Gradient](https://arxiv.org/abs/1312.6034) (Simonyan et al., 2013) +- `input_x_gradient`: [Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps](https://arxiv.org/abs/1312.6034) (Simonyan et al., 2013) -- `integrated_gradients`: [Integrated Gradients](https://arxiv.org/abs/1703.01365) (Sundararajan et al., 2017) +- `integrated_gradients`: [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365) (Sundararajan et al., 2017) -- `deeplift`: [DeepLIFT](https://arxiv.org/abs/1704.02685) (Shrikumar et al., 2017) +- `deeplift`: [Learning Important Features Through Propagating Activation Differences](https://arxiv.org/abs/1704.02685) (Shrikumar et al., 2017) -- `gradient_shap`: [Gradient 
SHAP](https://dl.acm.org/doi/10.5555/3295222.3295230) (Lundberg and Lee, 2017) +- `gradient_shap`: [A unified approach to interpreting model predictions](https://dl.acm.org/doi/10.5555/3295222.3295230) (Lundberg and Lee, 2017) -- `discretized_integrated_gradients`: [Discretized Integrated Gradients](https://aclanthology.org/2021.emnlp-main.805/) (Sanyal and Ren, 2021) +- `discretized_integrated_gradients`: [Discretized Integrated Gradients for Explaining Language Models](https://aclanthology.org/2021.emnlp-main.805/) (Sanyal and Ren, 2021) + +- `sequential_integrated_gradients`: [Sequential Integrated Gradients: a simple but effective method for explaining language models](https://aclanthology.org/2023.findings-acl.477/) (Enguehard, 2023) #### Internals-based attribution -- `attention`: [Attention Weight Attribution](https://arxiv.org/abs/1409.0473) (Bahdanau et al., 2014) +- `attention`: Attention Weight Attribution, from [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) (Bahdanau et al., 2014) #### Perturbation-based attribution -- `occlusion`: [Occlusion](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) (Zeiler and Fergus, 2014) +- `occlusion`: [Visualizing and Understanding Convolutional Networks](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_53) (Zeiler and Fergus, 2014) -- `lime`: [LIME](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016) +- `lime`: ["Why Should I Trust You?": Explaining the Predictions of Any Classifier](https://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016) #### Step functions @@ -262,9 +264,10 @@ Inseq has been used in various research projects. A list of known publications t
2023
    -
  1. Inseq: An Interpretability Toolkit for Sequence Generation Models (Sarti et al., 2023)
  2. +
  3. Inseq: An Interpretability Toolkit for Sequence Generation Models (Sarti et al., 2023)
  4. Are Character-level Translations Worth the Wait? Comparing Character- and Subword-level Models for Machine Translation (Edman et al., 2023)
  5. -
  6. Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps? (Mousavi et al., 2023)
  7. +
  8. Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps? (Mousavi et al., 2023)
  9. +
  10. Quantifying the Plausibility of Context Reliance in Neural Machine Translation (Sarti et al., 2023)
diff --git a/docs/source/main_classes/feature_attribution.rst b/docs/source/main_classes/feature_attribution.rst index e45aabe0..174b405c 100644 --- a/docs/source/main_classes/feature_attribution.rst +++ b/docs/source/main_classes/feature_attribution.rst @@ -28,9 +28,6 @@ Gradient Attribution Methods :members: -.. warning:: - The DiscretizedIntegratedGradientsAttribution class is currently exhibiting inconsistent behavior, so usage should be limited until further notice. See PR `# 114 `__ for additional info. - .. autoclass:: inseq.attr.feat.DiscretizedIntegratedGradientsAttribution :members: @@ -50,6 +47,10 @@ Gradient Attribution Methods :members: +.. autoclass:: inseq.attr.feat.SequentialIntegratedGradientsAttribution + :members: + + Layer Attribution Methods ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/inseq/attr/feat/__init__.py b/inseq/attr/feat/__init__.py index 14183894..cc07f530 100644 --- a/inseq/attr/feat/__init__.py +++ b/inseq/attr/feat/__init__.py @@ -11,6 +11,7 @@ LayerGradientXActivationAttribution, LayerIntegratedGradientsAttribution, SaliencyAttribution, + SequentialIntegratedGradientsAttribution, ) from .internals_attribution import AttentionWeightsAttribution, InternalsAttributionRegistry from .perturbation_attribution import ( @@ -37,4 +38,5 @@ "AttentionWeightsAttribution", "OcclusionAttribution", "LimeAttribution", + "SequentialIntegratedGradientsAttribution", ] diff --git a/inseq/attr/feat/gradient_attribution.py b/inseq/attr/feat/gradient_attribution.py index 6b354e5e..b86167e5 100644 --- a/inseq/attr/feat/gradient_attribution.py +++ b/inseq/attr/feat/gradient_attribution.py @@ -32,7 +32,7 @@ from ..attribution_decorators import set_hook, unset_hook from .attribution_utils import get_source_target_attributions from .feature_attribution import FeatureAttribution -from .ops import DiscretetizedIntegratedGradients +from .ops import 
DiscretetizedIntegratedGradients, SequentialIntegratedGradients logger = logging.getLogger(__name__) @@ -212,6 +212,22 @@ def __init__(self, attribution_model): self.method = Saliency(self.attribution_model) +class SequentialIntegratedGradientsAttribution(GradientAttributionRegistry): + """Sequential Integrated Gradients attribution method. + + Reference: https://aclanthology.org/2023.findings-acl.477/ + + Original implementation: https://github.com/josephenguehard/time_interpret/blob/main/tint/attr/seq_ig.py + """ + + method_name = "sequential_integrated_gradients" + + def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs): + super().__init__(attribution_model) + self.method = SequentialIntegratedGradients(self.attribution_model, multiply_by_inputs) + self.use_baselines = True + + # Layer methods diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py index a30ccb5e..226246fd 100644 --- a/inseq/attr/feat/internals_attribution.py +++ b/inseq/attr/feat/internals_attribution.py @@ -18,7 +18,6 @@ from captum._utils.typing import TensorOrTupleOfTensorsGeneric from captum.attr._utils.attribution import Attribution -from captum.log import log_usage from ...data import MultiDimensionalFeatureAttributionStepOutput from ...utils import Registry @@ -43,7 +42,6 @@ class AttentionWeights(Attribution): def has_convergence_delta() -> bool: return False - @log_usage() def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, diff --git a/inseq/attr/feat/ops/__init__.py b/inseq/attr/feat/ops/__init__.py index 83439681..388ab042 100644 --- a/inseq/attr/feat/ops/__init__.py +++ b/inseq/attr/feat/ops/__init__.py @@ -1,9 +1,11 @@ from .discretized_integrated_gradients import DiscretetizedIntegratedGradients from .lime import Lime from .monotonic_path_builder import MonotonicPathBuilder +from .sequential_integrated_gradients import SequentialIntegratedGradients __all__ = [ "DiscretetizedIntegratedGradients", 
"MonotonicPathBuilder", "Lime", + "SequentialIntegratedGradients", ] diff --git a/inseq/attr/feat/ops/discretized_integrated_gradients.py b/inseq/attr/feat/ops/discretized_integrated_gradients.py index b0cd0655..73b06bf7 100644 --- a/inseq/attr/feat/ops/discretized_integrated_gradients.py +++ b/inseq/attr/feat/ops/discretized_integrated_gradients.py @@ -31,7 +31,6 @@ from captum.attr._core.integrated_gradients import IntegratedGradients from captum.attr._utils.batching import _batch_attribution from captum.attr._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input -from captum.log import log_usage from torch import Tensor from ....utils import INSEQ_ARTIFACTS_CACHE @@ -87,7 +86,6 @@ def get_inputs_baselines(scaled_features_tpl: Tuple[Tensor, ...], n_steps: int) ) return inputs, baselines - @log_usage() def attribute( # type: ignore self, inputs: MultiStepEmbeddingsTensor, diff --git a/inseq/attr/feat/ops/lime.py b/inseq/attr/feat/ops/lime.py index 5b2b040f..eedf1f8a 100644 --- a/inseq/attr/feat/ops/lime.py +++ b/inseq/attr/feat/ops/lime.py @@ -62,7 +62,6 @@ def __init__( ) self.attribution_model = attribution_model - # @log_usage def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, diff --git a/inseq/attr/feat/ops/sequential_integrated_gradients.py b/inseq/attr/feat/ops/sequential_integrated_gradients.py new file mode 100644 index 00000000..92322815 --- /dev/null +++ b/inseq/attr/feat/ops/sequential_integrated_gradients.py @@ -0,0 +1,465 @@ +# Adapted from https://github.com/josephenguehard/time_interpret/blob/main/tint/attr/seq_ig.py, licensed MIT: +# Copyright © 2023 Babylon Health + +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and +# associated documentation files (the “Software”), to deal in the Software without restriction, +# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to 
permit persons to whom the Software is furnished to do so, +# subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all copies +# or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT +# LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import typing +from typing import Any, Callable, List, Tuple, Union + +import torch +from captum._utils.common import ( + _expand_additional_forward_args, + _expand_target, + _format_additional_forward_args, + _format_output, + _is_tuple, +) +from captum._utils.typing import ( + BaselineType, + Literal, + TargetType, + TensorOrTupleOfTensorsGeneric, +) +from captum.attr._utils.approximation_methods import approximation_parameters +from captum.attr._utils.attribution import GradientAttribution +from captum.attr._utils.batching import _batch_attribution +from captum.attr._utils.common import ( + _format_input_baseline, + _reshape_and_sum, + _validate_input, +) +from torch import Tensor + + +class SequentialIntegratedGradients(GradientAttribution): + r""" + Sequential Integrated Gradients. + + This method is the regular Integrated Gradients (IG) applied on each + component of a sequence. However, the baseline is specific to each + component: it keeps fixed the rest of the sequence while only setting the + component of interest to a reference baseline. 
+ + For instance, on a sentence of m words, the attribution of each word is + computed by running IG with a specific baseline: fixing every other word + to their current value, and replacing the word of interest with "", + an uninformative baseline. + + This method can be computationally expensive on long sequences, as it + needs to compute IG on each component individually. It is therefore + suggested to reduce ``n_steps`` when using this method on long sequences. + + Args: + forward_func (callable): The forward function of the model or any + modification of it + multiply_by_inputs (bool, optional): Indicates whether to factor + model inputs' multiplier in the final attribution scores. + In the literature this is also known as local vs global + attribution. If inputs' multiplier isn't factored in, + then that type of attribution method is also called local + attribution. If it is, then that type of attribution + method is called global. + More details can be found here: + https://arxiv.org/abs/1711.06104 + + In case of integrated gradients, if `multiply_by_inputs` + is set to True, final sensitivity scores are being multiplied by + (inputs - baselines). 
+ + References: + `Sequential Integrated Gradients: a simple but effective method for explaining language models + `_ + + Examples: + >>> import torch as th + >>> from tint.attr import SequentialIntegratedGradients + >>> from tint.models import MLP + + >>> inputs = th.rand(8, 7, 5) + >>> mlp = MLP([5, 3, 1]) + + >>> explainer = SequentialIntegratedGradients(mlp) + >>> attr = explainer.attribute(inputs, target=0) + """ + + def __init__( + self, + forward_func: Callable, + multiply_by_inputs: bool = True, + ) -> None: + r""" + Args: + + + """ + GradientAttribution.__init__(self, forward_func) + self._multiply_by_inputs = multiply_by_inputs + + # The following overloaded method signatures correspond to the case where + # return_convergence_delta is False, then only attributions are returned, + # and when return_convergence_delta is True, the return type is + # a tuple with both attributions and deltas. + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: Literal[False] = False, + ) -> TensorOrTupleOfTensorsGeneric: + ... + + @typing.overload + def attribute( + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + *, + return_convergence_delta: Literal[True], + ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]: + ... 
+ + def attribute( # type: ignore + self, + inputs: TensorOrTupleOfTensorsGeneric, + baselines: BaselineType = None, + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + internal_batch_size: Union[None, int] = None, + return_convergence_delta: bool = False, + ) -> Union[TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor],]: + r""" + This method attributes the output of the model with given target index + (in case it is provided, otherwise it assumes that output is a + scalar) to the inputs of the model using the approach described above. + + In addition to that it also returns, if `return_convergence_delta` is + set to True, integral approximation delta based on the completeness + property of integrated gradients. + + Args: + + inputs (tensor or tuple of tensors): Input for which integrated + gradients are computed. If forward_func takes a single + tensor as input, a single input tensor should be provided. + If forward_func takes multiple tensors as input, a tuple + of the input tensors should be provided. It is assumed + that for all given input tensors, dimension 0 corresponds + to the number of examples, and if multiple input tensors + are provided, the examples must be aligned appropriately. + baselines (scalar, tensor, tuple of scalars or tensors, optional): + Baselines define the starting point from which integral + is computed and can be provided as: + + - a single tensor, if inputs is a single tensor, with + exactly the same dimensions as inputs or the first + dimension is one and the remaining dimensions match + with inputs. + + - a single scalar, if inputs is a single tensor, which will + be broadcasted for each input value in input tensor. 
+ + - a tuple of tensors or scalars, the baseline corresponding + to each tensor in the inputs' tuple can be: + + - either a tensor with matching dimensions to + corresponding tensor in the inputs' tuple + or the first dimension is one and the remaining + dimensions match with the corresponding + input tensor. + + - or a scalar, corresponding to a tensor in the + inputs' tuple. This scalar value is broadcasted + for corresponding input tensor. + + In the cases when `baselines` is not provided, we internally + use zero scalar corresponding to each input tensor. + + Default: None + target (int, tuple, tensor or list, optional): Output indices for + which gradients are computed (for classification cases, + this is usually the target class). + If the network returns a scalar value per example, + no target index is necessary. + For general 2D outputs, targets can be either: + + - a single integer or a tensor containing a single + integer, which is applied to all input examples + + - a list of integers or a 1D tensor, with length matching + the number of examples in inputs (dim 0). Each integer + is applied as the target for the corresponding example. + + For outputs with > 2 dimensions, targets can be either: + + - A single tuple, which contains #output_dims - 1 + elements. This target index is applied to all examples. + + - A list of tuples with length equal to the number of + examples in inputs (dim 0), and each tuple containing + #output_dims - 1 elements. Each tuple is applied as the + target for the corresponding example. + + Default: None + additional_forward_args (any, optional): If the forward function + requires additional arguments other than the inputs for + which attributions should not be computed, this argument + can be provided. It must be either a single additional + argument of a Tensor or arbitrary (non-tuple) type or a + tuple containing multiple additional arguments including + tensors or any arbitrary python types. 
These arguments + are provided to forward_func in order following the + arguments in inputs. + For a tensor, the first dimension of the tensor must + correspond to the number of examples. It will be + repeated for each of `n_steps` along the integrated + path. For all other types, the given argument is used + for all forward evaluations. + Note that attributions are not computed with respect + to these arguments. + Default: None + n_steps (int, optional): The number of steps used by the approximation + method. Default: 50. + method (string, optional): Method for approximating the integral, + one of `riemann_right`, `riemann_left`, `riemann_middle`, + `riemann_trapezoid` or `gausslegendre`. + Default: `gausslegendre` if no method is provided. + internal_batch_size (int, optional): Divides total #steps * #examples + data points into chunks of size at most internal_batch_size, + which are computed (forward / backward passes) + sequentially. internal_batch_size must be at least equal to + #examples. + For DataParallel models, each batch is split among the + available devices, so evaluations on each available + device contain internal_batch_size / num_devices examples. + If internal_batch_size is None, then all evaluations are + processed in one batch. + Default: None + return_convergence_delta (bool, optional): Indicates whether to return + convergence delta or not. If `return_convergence_delta` + is set to True convergence delta will be returned in + a tuple following attributions. + Default: False + + Returns: + **attributions** or 2-element tuple of **attributions**, **delta**: + - **attributions** (*tensor* or tuple of *tensors*): + Integrated gradients with respect to each input feature. + attributions will always be the same size as the provided + inputs, with each value providing the attribution of the + corresponding input index. + If a single tensor is provided as inputs, a single tensor is + returned. 
If a tuple is provided for inputs, a tuple of + corresponding sized tensors is returned. + - **delta** (*tensor*, returned if return_convergence_delta=True): + The difference between the total approximated and true + integrated gradients. This is computed using the property + that the total sum of forward_func(inputs) - + forward_func(baselines) must equal the total sum of the + integrated gradient. + Delta is calculated per example, meaning that the number of + elements in returned delta tensor is equal to the number of + of examples in inputs. + + Examples:: + + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, + >>> # and returns an Nx10 tensor of class probabilities. + >>> net = ImageClassifier() + >>> sig = SequentialIntegratedGradients(net) + >>> input = torch.randn(2, 3, 32, 32, requires_grad=True) + >>> # Computes integrated gradients for class 3. + >>> attribution = sig.attribute(input, target=3) + """ + # Keeps track whether original input is a tuple or not before + # converting it into a tuple. + is_inputs_tuple = _is_tuple(inputs) + + inputs, baselines = _format_input_baseline(inputs, baselines) + + _validate_input(inputs, baselines, n_steps, method) + + assert all( + x.shape[1] == inputs[0].shape[1] for x in inputs + ), "All inputs must have the same sequential dimension. 
(dimension 1)" + + indexes = range(inputs[0].shape[1]) + + # Loop over the sequence + attributions_partial_list = [] + for idx in indexes: + if internal_batch_size is not None: + num_examples = inputs[0].shape[0] + attributions_partial = _batch_attribution( + self, + num_examples, + internal_batch_size, + n_steps, + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + method=method, + idx=idx, + ) + else: + attributions_partial = self._attribute( + inputs=inputs, + baselines=baselines, + target=target, + additional_forward_args=additional_forward_args, + n_steps=n_steps, + method=method, + idx=idx, + ) + + attributions_partial_list.append(attributions_partial) + + # Merge collected attributions + attributions = () + for i in range(len(attributions_partial_list[0])): + attributions += ( + torch.stack( + [x[i][:, idx, ...] for idx, x in enumerate(attributions_partial_list)], + dim=1, + ), + ) + + if return_convergence_delta: + start_point, end_point = baselines, inputs + # computes approximation error based on the completeness axiom + delta = self.compute_convergence_delta( + attributions, + start_point, + end_point, + additional_forward_args=additional_forward_args, + target=target, + ) + return _format_output(is_inputs_tuple, attributions), delta + return _format_output(is_inputs_tuple, attributions) + + def _attribute( + self, + inputs: Tuple[Tensor, ...], + baselines: Tuple[Union[Tensor, int, float], ...], + target: TargetType = None, + additional_forward_args: Any = None, + n_steps: int = 50, + method: str = "gausslegendre", + idx: int = None, + step_sizes_and_alphas: Union[None, Tuple[List[float], List[float]]] = None, + ) -> Tuple[Tensor, ...]: + if step_sizes_and_alphas is None: + # retrieve step size and scaling factor for specified + # approximation method + step_sizes_func, alphas_func = approximation_parameters(method) + step_sizes, alphas = step_sizes_func(n_steps), alphas_func(n_steps) + else: + 
step_sizes, alphas = step_sizes_and_alphas + + # Keep only idx index if baselines is a tensor + baselines_ = tuple( + baseline[:, idx, ...] if isinstance(baseline, Tensor) else baseline for baseline in baselines + ) + + # scale features and compute gradients. (batch size is abbreviated as bsz) + # scaled_features' dim -> (bsz * #steps x inputs[0].shape[1:], ...) + # Only scale features on the idx index. + scaled_features_tpl = tuple( + torch.cat( + [ + torch.cat( + [input[:, :idx, ...] for _ in alphas], + dim=0, + ).requires_grad_(), + torch.cat( + [baseline + alpha * (input[:, idx, ...] - baseline) for alpha in alphas], + dim=0, + ) + .unsqueeze(1) + .requires_grad_(), + torch.cat( + [input[:, idx + 1 :, ...] for _ in alphas], + dim=0, + ).requires_grad_(), + ], + dim=1, + ) + for input, baseline in zip(inputs, baselines_) + ) + + additional_forward_args = _format_additional_forward_args(additional_forward_args) + # apply number of steps to additional forward args + # currently, number of steps is applied only to additional forward arguments + # that are nd-tensors. It is assumed that the first dimension is + # the number of batches. + # dim -> (bsz * #steps x additional_forward_args[0].shape[1:], ...) + input_additional_args = ( + _expand_additional_forward_args(additional_forward_args, n_steps) + if additional_forward_args is not None + else None + ) + expanded_target = _expand_target(target, n_steps) + + # grads: dim -> (bsz * #steps x inputs[0].shape[1:], ...) 
+ grads = self.gradient_func( + forward_fn=self.forward_func, + inputs=scaled_features_tpl, + target_ind=expanded_target, + additional_forward_args=input_additional_args, + ) + + # flattening grads so that we can multiply it with step-size + # calling contiguous to avoid `memory whole` problems + scaled_grads = [ + grad.contiguous().view(n_steps, -1) * torch.tensor(step_sizes).view(n_steps, 1).to(grad.device) + for grad in grads + ] + + # aggregates across all steps for each tensor in the input tuple + # total_grads has the same dimensionality as inputs + total_grads = tuple( + _reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]) + for (scaled_grad, grad) in zip(scaled_grads, grads) + ) + + # computes attribution for each tensor in input tuple + # attributions has the same dimensionality as inputs + if not self.multiplies_by_inputs: + attributions = total_grads + else: + attributions = tuple( + total_grad * (input - baseline) for total_grad, input, baseline in zip(total_grads, inputs, baselines) + ) + return attributions + + def has_convergence_delta(self) -> bool: + return True + + @property + def multiplies_by_inputs(self): + return self._multiply_by_inputs diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py index 3db75760..05b8a468 100644 --- a/inseq/models/model_config.py +++ b/inseq/models/model_config.py @@ -45,6 +45,7 @@ def register_model_config( ) -> None: """Allows to register a model configuration for a given model type. The configuration is a dictionary containing information required the methods for which the attribute ``use_model_config=True``. + Args: model_type (`str`): The class of the model for which the configuration is registered, used as key in the stored configuration. @@ -57,6 +58,7 @@ def register_model_config( allow_partial (`bool`, *optional*, defaults to False): If `True`, the configuration can be partial, i.e. 
it can contain only a subset of the fields of the :class:`~inseq.models.ModelConfig` class. The missing fields will be set to `None`. + Raises: `ValueError`: If the model type is already registered and `overwrite=False`, or if the configuration is partial and `allow_partial=False`.