Skip to content

Commit

Permalink
Fix MultiDimensional attribution aggregation (inseq-team#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti authored Nov 10, 2024
1 parent 97e5021 commit 1078e19
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ out_sliced = out.aggregate("slices", target_spans=(13,73))
out_sliced = out[13:73]
```

- A new `StringSplitAggregator` (`"split"`) is added to allow for supporting more complex aggregation procedures beyond simple subword merging in `FeatureAttributionSequenceOutput` objects. More specifically, splitting supports regex expressions to match split points even when these are (potentially overlapping) parts of existing tokens. The `split_mode` parameter can be set to `"single"` (default) to keep tokens containing matched split points separate while aggregating the rest, or `"start"` or `"end"` to concatenate them to the preceding/following aggregated token sequence. [#290](https://github.com/inseq-team/inseq/pull/290)

```python
# Split on newlines. Default split_mode = "single".
out.aggregate("split", split_pattern="\n").aggregate("sum").show(do_aggregation=False)

# Split on whitespace-separated words of length 5.
# Note: this works if clean_special_chars = True is used, otherwise the split_pattern should be adjusted to split on special characters like "Ġ" or "▁".
out.aggregate("split", split_pattern=r"\s(\w{5})(?=\s)", split_mode="end")
```

- The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` ([#282](https://github.com/inseq-team/inseq/pull/282)).


Expand Down Expand Up @@ -90,6 +101,8 @@ out_female = attrib_model.attribute(

- Fix support for multi-EOS tokens (e.g. LLaMA 3.2, see [#287](https://github.com/inseq-team/inseq/issues/287)).

- Fix copying configuration parameters to aggregated `FeatureAttributionSequenceOutput` objects ([#292](https://github.com/inseq-team/inseq/pull/292)).

## 📝 Documentation and Tutorials

- Updated tutorial with `treescope` usage examples.
Expand Down
8 changes: 7 additions & 1 deletion inseq/data/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _aggregate(
kwargs["aggregate_fn"] = kwargs["aggregate_fn"][cls.aggregator_family]
field_func = getattr(cls, f"aggregate_{field}")
aggregated_sequence_attribution_fields[field] = field_func(attr, **kwargs)
return attr.__class__(**aggregated_sequence_attribution_fields)
return attr.__class__(**aggregated_sequence_attribution_fields, **attr.config)

@classmethod
def _process_attribution_scores(
Expand Down Expand Up @@ -346,6 +346,12 @@ def _process_attribution_scores(
@classmethod
def post_aggregate_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs):
    """Refresh the dimensionality metadata of ``attr`` after aggregation.

    After the parent hook runs, the number of attribution dimensions is
    re-derived from whichever attribution tensor is available (source first,
    then target), defaulting to 0 when neither is present, and the result is
    validated for compatibility with this aggregator.
    """
    super().post_aggregate_hook(attr, **kwargs)
    # Prefer source attributions; fall back to target; 0 when neither exists.
    ndim = 0
    for scores in (attr.source_attributions, attr.target_attributions):
        if scores is not None:
            ndim = scores.ndim
            break
    attr._num_dimensions = ndim
    cls.is_compatible(attr)

@classmethod
Expand Down
13 changes: 10 additions & 3 deletions inseq/data/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class FeatureAttributionSequenceOutput(TensorWrapper, AggregableMixin):
_aggregator: str | list[str] | None = None
_dict_aggregate_fn: dict[str, str] | None = None
_attribution_dim_names: dict[str, dict[int, str]] | None = None
_num_dimensions: int | None = None

def __post_init__(self):
if self._dict_aggregate_fn is None:
Expand All @@ -181,6 +182,8 @@ def __post_init__(self):
self._attribution_dim_names = default_dim_names
if self._aggregator is None:
self._aggregator = "scores"
if self._num_dimensions is None:
self._num_dimensions = 0
if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
self.attr_pos_end = len(self.target)

Expand Down Expand Up @@ -309,6 +312,10 @@ def _recover_from_safetensors(self):
}
return self

@property
def config(self) -> dict[str, Any]:
    """Return all private (underscore-prefixed) attributes as a configuration dict.

    Used to propagate configuration parameters when constructing aggregated
    copies of this object.
    """
    cfg: dict[str, Any] = {}
    for key, value in self.__dict__.items():
        if key.startswith("_"):
            cfg[key] = value
    return cfg

@staticmethod
def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
if attr.source_attributions is None or name.startswith("decoder"):
Expand Down Expand Up @@ -696,10 +703,13 @@ class FeatureAttributionStepOutput(TensorWrapper):
source: OneOrMoreTokenWithIdSequences | None = None
prefix: OneOrMoreTokenWithIdSequences | None = None
target: OneOrMoreTokenWithIdSequences | None = None
_num_dimensions: int | None = None
_sequence_cls: type["FeatureAttributionSequenceOutput"] = FeatureAttributionSequenceOutput

def __post_init__(self):
self.to(torch.float32)
if self._num_dimensions is None:
self._num_dimensions = 0
if self.step_scores is None:
self.step_scores = {}
if self.sequence_scores is None:
Expand Down Expand Up @@ -1213,8 +1223,6 @@ class MultiDimensionalFeatureAttributionSequenceOutput(FeatureAttributionSequenc
attention head and per layer for every source-target token pair in the source attributions (i.e. 2 dimensions).
"""

_num_dimensions: int = 2

def __post_init__(self):
super().__post_init__()
self._aggregator = ["mean"] * self._num_dimensions
Expand All @@ -1233,7 +1241,6 @@ def __post_init__(self):
class MultiDimensionalFeatureAttributionStepOutput(FeatureAttributionStepOutput):
"""Raw output of a single step of multi-dimensional feature attribution."""

_num_dimensions: int = 2
_sequence_cls: type["FeatureAttributionSequenceOutput"] = MultiDimensionalFeatureAttributionSequenceOutput

def get_sequence_cls(self, **kwargs):
Expand Down

0 comments on commit 1078e19

Please sign in to comment.