
Commit

treescope support and new visualizations (inseq-team#283)
* Add treescope requirement

* Drop Python 3.9 support

* Fix tornado version for safety

* show_granular first version working

* Add __treescope_repr__ to FeatureAttributionSequenceOutput

* Add slicing for show_granular, started show_tokens

* Finished show_tokens

* Fix vmin for step scores

* Fix lint

* Add docs

* Update changelog

* Fix viz for attribute_context, improved cmaps

* Fix safety
gsarti authored Aug 9, 2024
1 parent e5b835b commit 904c893
Showing 62 changed files with 1,674 additions and 774 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -14,7 +14,7 @@ sphinx:
build:
os: ubuntu-20.04
tools:
python: "3.9"
python: "3.10"

python:
install:
8 changes: 6 additions & 2 deletions CHANGELOG.md
@@ -4,6 +4,10 @@

## 🚀 Features

- Added [treescope](https://github.com/google-deepmind/treescope) for interactive model and tensor visualization. ([#283](https://github.com/inseq-team/inseq/pull/283))

- New `treescope`-powered methods `FeatureAttributionOutput.show_granular` and `FeatureAttributionSequenceOutput.show_tokens` for interactive visualization of multidimensional attribution tensors and token highlights. ([#283](https://github.com/inseq-team/inseq/pull/283))

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

- Add `rescale_attributions` to Inseq CLI commands for `rescale=True` ([#280](https://github.com/inseq-team/inseq/pull/280)).
@@ -84,8 +88,8 @@ out_female = attrib_model.attribute(

## 📝 Documentation and Tutorials

*No changes*
- Updated tutorial with `treescope` usage examples.

## 💥 Breaking Changes

*No changes*
- Dropped support for Python 3.9. Please use Python >= 3.10. ([#283](https://github.com/inseq-team/inseq/pull/283))
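Based on the changelog entries above, a hypothetical usage sketch for the two new visualization methods might look as follows. The model name, attribution method, and attribute names such as `sequence_attributions` are assumptions for illustration, not taken from this diff; the call is wrapped in a function so the snippet stays importable even without `inseq` installed:

```python
# Hypothetical sketch only: names and signatures are assumptions, check the Inseq docs.
def visualize_attributions(text: str = "Hello world, how are you?") -> None:
    import inseq  # requires `pip install inseq` (pulls in the treescope dependency)

    model = inseq.load_model("gpt2", "saliency")  # model/method chosen arbitrarily
    out = model.attribute(text)
    out.show_granular()  # interactive view of the multidimensional attribution tensors
    out.sequence_attributions[0].show_tokens()  # per-token highlight view

# visualize_attributions()  # uncomment in a notebook with inseq installed
```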
2 changes: 1 addition & 1 deletion Makefile
@@ -82,7 +82,7 @@ fix-style:

.PHONY: check-safety
check-safety:
$(PYTHON) -m safety check --full-report -i 70612 -i 71670
$(PYTHON) -m safety check --full-report -i 70612 -i 71670 -i 72089

.PHONY: lint
lint: fix-style check-safety
6 changes: 3 additions & 3 deletions README.md
@@ -13,12 +13,12 @@
[![Downloads](https://static.pepy.tech/badge/inseq)](https://pepy.tech/project/inseq)
[![License](https://img.shields.io/github/license/inseq-team/inseq)](https://github.com/inseq-team/inseq/blob/main/LICENSE)
[![Demo Paper](https://img.shields.io/badge/ACL%20Anthology%20-%20?logo=data%3Aimage%2Fx-icon%3Bbase64%2CAAABAAEAIBIAAAEAIABwCQAAFgAAACgAAAAgAAAAJAAAAAEAIAAAAAAAAAkAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACEa7k0mH%2B%2F5JBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQd7f8kHe3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQd7f8kHO3%2FJB3t%2FyQd7f8kHO3%2FJB3t%2FyQd7f8kHe3%2FJB3t%2FyMc79EkGP8VAAAAAAAAAAAAAAAAIRruTSYf7%2FkkHO3%2FJBzt%2FyQd7f8kHe3%2FJB3t%2FyQd7f8kHO3%2FJB3t%2FyQc7f8kHe3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJB3t%2FyQd7f8kHe3%2FIxzv0SQY%2FxUAAAAAAAAAAAAAAAAhIe5NJh%2Fv%2BSQd7f8kHe3%2FJB3t%2FyQd7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJBzt%2FyQd7f8kHe3%2FJBzt%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHe3%2FJB3t%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHe3%2FJBzt%2FyQd7f8jHO%2FRJBj%2FFQAAAAAAAAAAAAAAACEa7k0mH%2B%2F5JB3t%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJBzt%2FyQd7f8kHe3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJBzt%2FyQd7f8kHe3%2FJB3t%2FyQc7f8kHe3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyMc79EkGP8VAAAAAAAAAAAAAAAAIRruTSYf7%2FkkHO3%2FJBzt%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHO3%2FJB3t%2FyQd7f8kHO3%2FJB3t%2FyQc7f8kHe3%2FJBzt%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHe3%2FJBzt%2FyQd7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FIxzv0SQY%2FxUAAAAAAAAAAAAAAAAhGu5NJh%2Fv%2BSQd7f8kHO3%2FJBzt%2FyQd7f8jIOzYIxvtgiQc8X8kHPF%2FJBzxfyQc8X8kHPF%2FJBzxfyQc8X8kHPF%2FIx%2FuiiMf7OgkHe3%2FJBzt%2FyQc7f8kHO3%2FJhzs9CUg7JYkHPF%2FJBzxfyQc8X8iHfBoMzP%2FCgAAAAAAAAAAAAAAACEa7k0mHu%2F5JBzt%2FyQc7f8kHO3%2FJBzt%2FyQb7LEAAP8FAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkGP8VIxzv0SQc7f8kHe3%2FJB3t%2FyQc7f8jHOzpIhzuLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIRruTSYf7%2FkkHe3%2FJB3t%2FyQd7f8kHO3%2FJBvssQAA%2FwUAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACQY%2FxUjHO%2FRJB3t%2FyQc7f8kHO3%2FJBzt%2FyMb7OkiHO4tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAhGu5NJh%2Fv%2BSQd7f8kHe3%2FJBzt%2FyQc7f8kHuyxAAD%2FBQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAJBj%2FFSMc79EkHe3%2FJBzt%2FyQc7f8kHO3%2FIxvs6SIc7i0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AACEa7k0mH%2B%2F5JBzt%2FyQc7f8kHO3%2FJBzt%2FyQb7LEAAP8FAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAkGP8VIx3v0SQd7f8kHe3%2FJBzt%2FyQc7f8jHOzpIhzuLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIRruTSYf7%2FkkHO3%2FJBzt%2FyQc7f8kHe3%2FJBzuxSgi81MhGu5NIRruTSEa7k0hGu5NISHuTSEh7k0hGu5NIRruTSIa72EjHe3aJBzt%2FyQd7f8kHO3%2FJBzt%2FyMc7OkiHO4tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAhGu5NJh7v%2BSQc7f8kHe3%2FJBzt%2FyQc7f8kHO3%2FJh%2Fv%2BSYf7%2FkmHu%2F5Jh%2Fv%2BSYf7%2FkmH%2B%2F5Jh%2Fv%2BSYf7%2FkmH%2B%2F5Jh7v%2BSQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FIxzs6SIc7i0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACEa7k0mHu%2F5JBzt%2FyQd7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHe3%2FJBzt%2FyQd7f8jHOzpIhzuLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIRruTSYe7%2FkkHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQd7f8kHe3%2FJB3t%2FyQc7f8kHe3%2FJB3t%2FyQd7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyMc7OkiHO4tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAhGu5NJh%2Fv%2BSQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FIxzs6SIc7i0AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACEa7k0mH%2B%2F5JB3t%2FyQc7f8kHe3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQd7f8kHO3%2FJBzt%2FyQc7f8kHO3%2FJB3t%2FyQc7f8kHO3%2FJBzt%2FyQc7f8jHOzpIhzuLQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKB%2FtOSUc7askHuyxJBvssSQe7LEkHuyxJB7ssSQe7LEkHuyxJBvssSQb7LEkHuyxJB7ssSQe7LEkHuyxJB7ssSUc7LMjHe31JB3t%2FyQd7f8kHe3%2FJBzt%2FyMc7OkiHO4tAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAD%2FBQAA%2FwUAAP8FAAD%2FBQAA%2FwUAAP8FAAD%2FBQAA%2FwUAAP8FAAD%2FBQAA%2FwUAAP8FAAD%2FBQAA%2FwUAAP8FHBzsGyMd7qYjHO%2FRIxzv0SMd79EjHO%2FRIx7tux4Y%2BSoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADAAAABwAAAAcAAAAHAAAABwAAAAcAAAAHAP8A%2FwD%2FAP8A%2FwD%2FAP8A%2FwAAAP8AAAD%2FAAAA%2FwAAAP8AAAD%2FAAAA%2FwAAAP%2BAAAD8%3D&labelColor=white&color=red&link=https%3A%2F%2Faclanthology.org%2F2023.acl-demo.40%2F
)](http://arxiv.org/abs/2302.13942)
)](https://aclanthology.org/2023.acl-demo.40)

</div>
<div align="center">

[![Follow Inseq on Twitter]( https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
[![Follow Inseq on Twitter](https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
[![Join the Inseq Discord server](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/V5VgwwFPbu)
[![Read the Docs](https://img.shields.io/badge/-Docs-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://inseq.org)](https://inseq.org)
[![Tutorial](https://img.shields.io/badge/-Tutorial-orange?style=for-the-badge&logo=Jupyter&logoColor=white&link=https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)](https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)
@@ -30,7 +30,7 @@ Inseq is a Pytorch-based hackable toolkit to democratize access to common post-h

## Installation

Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.12:
Inseq is available on PyPI and can be installed with `pip` for Python >= 3.10, <= 3.12:

```bash
# Install latest stable version
4 changes: 4 additions & 0 deletions docs/source/main_classes/main_functions.rst
@@ -38,4 +38,8 @@ functionalities required for its usage.

.. autofunction:: show_attributions

.. autofunction:: show_granular_attributions

.. autofunction:: show_token_attributions

.. autofunction:: merge_attributions
456 changes: 273 additions & 183 deletions examples/inseq_tutorial.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions inseq/__init__.py
@@ -7,6 +7,8 @@
list_aggregators,
merge_attributions,
show_attributions,
show_granular_attributions,
show_token_attributions,
)
from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config
from .utils.id_utils import explain
@@ -28,6 +30,8 @@ def get_version() -> str:
"load_model",
"explain",
"show_attributions",
"show_granular_attributions",
"show_token_attributions",
"list_feature_attribution_methods",
"list_aggregators",
"list_aggregation_functions",
14 changes: 8 additions & 6 deletions inseq/attr/attribution_decorators.py
@@ -14,9 +14,9 @@
"""Decorators for attribution methods."""

import logging
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import wraps
from typing import Any, Callable, Optional
from typing import Any

from ..data.data_utils import TensorWrapper

@@ -55,14 +55,14 @@ def batched(f: Callable[..., Any]) -> Callable[..., Any]:
"""Decorator that enables batching of the args."""

@wraps(f)
def batched_wrapper(self, *args, batch_size: Optional[int] = None, **kwargs):
def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
def batched_wrapper(self, *args, batch_size: int | None = None, **kwargs):
def get_batched(bs: int | None, seq: Sequence[Any]) -> list[list[Any]]:
if isinstance(seq, str):
seq = [seq]
if isinstance(seq, list):
return [seq[i : i + bs] for i in range(0, len(seq), bs)] # noqa
if isinstance(seq, tuple):
return list(zip(*[get_batched(bs, s) for s in seq]))
return list(zip(*[get_batched(bs, s) for s in seq], strict=False))
elif isinstance(seq, TensorWrapper):
return [seq.slice_batch(slice(i, i + bs)) for i in range(0, len(seq), bs)] # noqa
else:
@@ -75,7 +75,9 @@ def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]:
len_batches = len(batched_args[0])
assert all(len(batch) == len_batches for batch in batched_args)
output = []
zipped_batched_args = zip(*batched_args) if len(batched_args) > 1 else [(x,) for x in batched_args[0]]
zipped_batched_args = (
zip(*batched_args, strict=False) if len(batched_args) > 1 else [(x,) for x in batched_args[0]]
)
for i, batch in enumerate(zipped_batched_args):
logger.debug(f"Batching enabled: processing batch {i + 1} of {len_batches}...")
out = f(self, *batch, **kwargs)
25 changes: 13 additions & 12 deletions inseq/attr/feat/attribution_utils.py
@@ -1,6 +1,7 @@
import logging
import math
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from ...utils import extract_signature_args, get_aligned_idx
from ...utils.typing import (
@@ -24,8 +25,8 @@
def tok2string(
attribution_model: "AttributionModel",
token_lists: OneOrMoreTokenSequences,
start: Optional[int] = None,
end: Optional[int] = None,
start: int | None = None,
end: int | None = None,
as_targets: bool = True,
) -> TextInput:
"""Enables bounded tokenization of a list of lists of tokens with start and end positions."""
@@ -42,14 +43,14 @@ def rescale_attributions_to_tokens(
) -> OneOrMoreAttributionSequences:
return [
attr[: len(tokens)] if not all(math.isnan(x) for x in attr) else []
for attr, tokens in zip(attributions, tokens)
for attr, tokens in zip(attributions, tokens, strict=False)
]


def check_attribute_positions(
max_length: int,
attr_pos_start: Optional[int] = None,
attr_pos_end: Optional[int] = None,
attr_pos_start: int | None = None,
attr_pos_end: int | None = None,
) -> tuple[int, int]:
r"""Checks whether the combination of start/end positions for attribution is valid.
@@ -88,8 +89,8 @@ def check_attribute_positions(
def join_token_ids(
tokens: OneOrMoreTokenSequences,
ids: OneOrMoreIdSequences,
contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
contrast_tokens: OneOrMoreTokenSequences | None = None,
contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
) -> list[TokenWithId]:
"""Joins tokens and ids into a list of TokenWithId objects."""
if contrast_tokens is None:
@@ -99,10 +100,10 @@ def join_token_ids(
contrast_targets_alignments = [[(idx, idx) for idx, _ in enumerate(seq)] for seq in tokens]
sequences = []
for target_tokens_seq, contrast_target_tokens_seq, input_ids_seq, alignments_seq in zip(
tokens, contrast_tokens, ids, contrast_targets_alignments
tokens, contrast_tokens, ids, contrast_targets_alignments, strict=False
):
curr_seq = []
for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq)):
for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq, strict=False)):
contrast_pos_idx = get_aligned_idx(pos_idx, alignments_seq)
if contrast_pos_idx != -1 and token != contrast_target_tokens_seq[contrast_pos_idx]:
curr_seq.append(TokenWithId(f"{contrast_target_tokens_seq[contrast_pos_idx]}{token}", -1))
@@ -142,10 +143,10 @@ def extract_args(


def get_source_target_attributions(
attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
attr: StepAttributionTensor | tuple[StepAttributionTensor, StepAttributionTensor],
is_encoder_decoder: bool,
has_sequence_scores: bool = False,
) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
) -> tuple[StepAttributionTensor | None, StepAttributionTensor | None]:
if isinstance(attr, tuple):
if is_encoder_decoder:
if has_sequence_scores:
23 changes: 12 additions & 11 deletions inseq/attr/feat/feature_attribution.py
@@ -17,8 +17,9 @@
* 🟡: Allow custom arguments for model loading in the :class:`FeatureAttribution` :meth:`load` method.
"""
import logging
from collections.abc import Callable
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional

import torch
from jaxtyping import Int
@@ -123,7 +124,7 @@ def load(
cls,
method_name: str,
attribution_model: Optional["AttributionModel"] = None,
model_name_or_path: Optional[ModelIdentifier] = None,
model_name_or_path: ModelIdentifier | None = None,
**kwargs,
) -> "FeatureAttribution":
r"""Load the selected method and hook it to an existing or available
@@ -168,16 +169,16 @@ def prepare_and_attribute(
self,
sources: FeatureAttributionInput,
targets: FeatureAttributionInput,
attr_pos_start: Optional[int] = None,
attr_pos_end: Optional[int] = None,
attr_pos_start: int | None = None,
attr_pos_end: int | None = None,
show_progress: bool = True,
pretty_progress: bool = True,
output_step_attributions: bool = False,
attribute_target: bool = False,
step_scores: list[str] = [],
include_eos_baseline: bool = False,
skip_special_tokens: bool = False,
attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
@@ -317,7 +318,7 @@ def format_contrastive_targets(
attr_pos_start: int,
attr_pos_end: int,
skip_special_tokens: bool = False,
) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
) -> tuple[DecoderOnlyBatch | None, list[list[tuple[int, int]]] | None, dict[str, Any], dict[str, Any]]:
contrast_batch, contrast_targets_alignments = None, None
contrast_targets = attributed_fn_args.get("contrast_targets", None)
if contrast_targets is None:
@@ -357,10 +358,10 @@ def format_contrastive_targets(

def attribute(
self,
batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
batch: DecoderOnlyBatch | EncoderDecoderBatch,
attributed_fn: Callable[..., SingleScorePerStepTensor],
attr_pos_start: Optional[int] = None,
attr_pos_end: Optional[int] = None,
attr_pos_start: int | None = None,
attr_pos_end: int | None = None,
show_progress: bool = True,
pretty_progress: bool = True,
output_step_attributions: bool = False,
@@ -545,10 +546,10 @@ def attribute(

def filtered_attribute_step(
self,
batch: Union[DecoderOnlyBatch, EncoderDecoderBatch],
batch: DecoderOnlyBatch | EncoderDecoderBatch,
target_ids: Int[torch.Tensor, "batch_size 1"],
attributed_fn: Callable[..., SingleScorePerStepTensor],
target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None,
target_attention_mask: Int[torch.Tensor, "batch_size 1"] | None = None,
attribute_target: bool = False,
step_scores: list[str] = [],
attribution_args: dict[str, Any] = {},
8 changes: 4 additions & 4 deletions inseq/attr/feat/internals_attribution.py
@@ -14,7 +14,7 @@
"""Attention-based feature attribution methods."""

import logging
from typing import Any, Optional
from typing import Any

from captum._utils.typing import TensorOrTupleOfTensorsGeneric

@@ -46,9 +46,9 @@ def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
additional_forward_args: TensorOrTupleOfTensorsGeneric,
encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None,
encoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None,
decoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None,
cross_attentions: MultiLayerMultiUnitScoreTensor | None = None,
) -> MultiDimensionalFeatureAttributionStepOutput:
"""Extracts the attention weights from the model.
13 changes: 7 additions & 6 deletions inseq/attr/feat/ops/discretized_integrated_gradients.py
@@ -16,8 +16,9 @@
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
# OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, Union
from typing import Any

import torch
from captum._utils.common import (
@@ -94,9 +95,9 @@ def attribute( # type: ignore
additional_forward_args: Any = None,
n_steps: int = 50,
method: str = "greedy",
internal_batch_size: Union[None, int] = None,
internal_batch_size: None | int = None,
return_convergence_delta: bool = False,
) -> Union[TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor]]:
) -> TensorOrTupleOfTensorsGeneric | tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
n_examples = inputs[0].shape[0]
# Keeps track whether original input is a tuple or not before
# converting it into a tuple.
@@ -112,7 +113,7 @@ def attribute( # type: ignore
n_steps=n_steps,
scale_strategy=method,
)
for input_tensor, baseline_tensor in zip(inputs, baselines)
for input_tensor, baseline_tensor in zip(inputs, baselines, strict=False)
)
if internal_batch_size is not None:
attributions = _batch_attribution(
@@ -181,7 +182,7 @@ def _attribute(
# total_grads has the same dimensionality as the original inputs
total_grads = tuple(
_reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:])
for (scaled_grad, grad) in zip(scaled_grads, grads)
for (scaled_grad, grad) in zip(scaled_grads, grads, strict=False)
)
# computes attribution for each tensor in input_tuple
# attributions has the same dimensionality as the original inputs
@@ -191,5 +192,5 @@
inputs, baselines = self.get_inputs_baselines(scaled_features_tpl, n_steps)
return tuple(
total_grad * (input - baseline)
for (total_grad, input, baseline) in zip(total_grads, inputs, baselines)
for (total_grad, input, baseline) in zip(total_grads, inputs, baselines, strict=False)
)