Improve attribution viz and SliceAggregator (#282)
* Improve visualization, allow slicing of attributions

* Added SliceAggregator and dunder shortcuts

* Minor fixes

* Fix safety, add changelogs
gsarti authored Jul 3, 2024
1 parent 01bd08b commit 5a46d51
Showing 10 changed files with 398 additions and 37 deletions.
47 changes: 47 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,50 @@

- Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM` to model config.

- Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)

- A new `SliceAggregator` (`"slices"`) is added to allow for slicing source (in encoder-decoder) or target (in decoder-only) tokens from a `FeatureAttributionSequenceOutput` object, using the same syntax as `ContiguousSpanAggregator`. The `__getitem__` method of `FeatureAttributionSequenceOutput` is a shortcut for this, allowing slicing with `[start:stop]` syntax. [#282](https://github.com/inseq-team/inseq/pull/282)

```python
import inseq
from inseq.data.aggregator import SliceAggregator

attrib_model = inseq.load_model("gpt2", "attention")
input_prompt = """Instruction: Summarize this article.
Input_text: In a quiet village nestled between rolling hills, an ancient tree whispered secrets to those who listened. One night, a curious child named Elara leaned close and heard tales of hidden treasures beneath the roots. As dawn broke, she unearthed a shimmering box, unlocking a forgotten world of wonder and magic.
Summary:"""

full_output_prompt = input_prompt + " Elara discovers a shimmering box under an ancient tree, unlocking a world of magic."

out = attrib_model.attribute(input_prompt, full_output_prompt)[0]

# These are all equivalent ways to slice only the input text contents
out_sliced = out.aggregate(SliceAggregator, target_spans=(13,73))
out_sliced = out.aggregate("slices", target_spans=(13,73))
out_sliced = out[13:73]
```

- The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` [#282](https://github.com/inseq-team/inseq/pull/282)


```python
import inseq

attrib_model = inseq.load_model("gpt2", "saliency")

out_male = attrib_model.attribute(
"The director went home because",
"The director went home because he was tired",
step_scores=["probability"]
)[0]
out_female = attrib_model.attribute(
"The director went home because",
"The director went home because she was tired",
step_scores=["probability"]
)[0]
(out_male - out_female).show()
```

## 🔧 Fixes and Refactoring

- Fix the issue in the attention implementation from [#268](https://github.com/inseq-team/inseq/issues/268) where non-terminal positions in the tensor were set to `nan` if they were 0s ([#269](https://github.com/inseq-team/inseq/pull/269)).
@@ -14,6 +58,9 @@

- Fix bug reported in [#266](https://github.com/inseq-team/inseq/issues/266) making `value_zeroing` unusable for SDPA attention. This enables using the method on models using SDPA attention as default (e.g. `GemmaForCausalLM`) without passing `model_kwargs={'attn_implementation': 'eager'}` ([#267](https://github.com/inseq-team/inseq/pull/267)).

- The directions of generated/attributed tokens in the visualization are now indicated with arrows instead of x/y labels [#282](https://github.com/inseq-team/inseq/pull/282)


## 📝 Documentation and Tutorials

*No changes*
2 changes: 1 addition & 1 deletion Makefile
@@ -82,7 +82,7 @@ fix-style:

.PHONY: check-safety
check-safety:
$(PYTHON) -m safety check --full-report
$(PYTHON) -m safety check --full-report -i 70612

.PHONY: lint
lint: fix-style check-safety
143 changes: 138 additions & 5 deletions inseq/data/aggregator.py
@@ -357,9 +357,9 @@ def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs
assert attr.target_attributions.ndim == 2, attr.target_attributions.shape
except AssertionError as e:
raise RuntimeError(
f"The aggregated attributions should be 2-dimensional to be visualized. Found dimensions: {e.args[0]}"
"If you're performing intermediate aggregation and don't aim to visualize the output right away, use"
"do_post_aggregation_checks=False in the aggregate method to bypass this check."
f"The aggregated attributions should be 2-dimensional to be visualized.\nFound dimensions: {e.args[0]}"
"\n\nIf you're performing intermediate aggregation and don't aim to visualize the output right away, "
"use do_post_aggregation_checks=False in the aggregate method to bypass this check."
) from e

@staticmethod
@@ -530,7 +530,7 @@ def format_spans(spans) -> list[tuple[int, int]]:
return [spans] if isinstance(spans[0], int) else spans

@classmethod
def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans: Optional[IndexSpan] = None):
def validate_spans(cls, span_sequence: list[TokenWithId], spans: Optional[IndexSpan] = None):
if not spans:
return
allmatch = lambda l, type: all(isinstance(x, type) for x in l)
@@ -545,7 +545,7 @@ def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans
assert (
span[0] >= prev_span_max
), f"Spans must be positive-valued, non-overlapping and in ascending order, got {spans}"
assert span[1] < len(span_sequence), f"Span values must be indexes of the original span, got {spans}"
assert span[1] <= len(span_sequence), f"Span values must be indexes of the original span, got {spans}"
prev_span_max = span[1]

@staticmethod
@@ -808,3 +808,136 @@ def aggregate_sequence_scores(attr, paired_attr, aggregate_fn, **kwargs):
agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
out_dict[name] = agg_fn(sequence_scores, paired_attr.sequence_scores[name])
return out_dict


class SliceAggregator(ContiguousSpanAggregator):
"""Slices the FeatureAttributionSequenceOutput object into a smaller one containing a subset of its elements.
Args:
attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The starting attribution object.
source_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the
source sequence. Defaults to no slicing performed.
target_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to slice for the
target sequence. Defaults to no slicing performed.
"""

aggregator_name = "slices"
default_fn = None

@classmethod
def aggregate(
cls,
attr: "FeatureAttributionSequenceOutput",
source_spans: Optional[IndexSpan] = None,
target_spans: Optional[IndexSpan] = None,
**kwargs,
):
"""Spans can be:
1. A list of the form [pos_start, pos_end] including the contiguous positions of tokens that
are to be aggregated, if all values are integers and len(span) < len(original_seq)
2. A list of the form [(pos_start_0, pos_end_0), (pos_start_1, pos_end_1)], same as above but
for multiple contiguous spans.
"""
source_spans = cls.format_spans(source_spans)
target_spans = cls.format_spans(target_spans)

if attr.source_attributions is None:
if source_spans is not None:
logger.warn(
"Source spans are specified but no source scores are given for decoder-only models. "
"Ignoring source spans and using target spans instead."
)
source_spans = [(s[0], min(s[1], attr.attr_pos_start)) for s in target_spans]

# Generated tokens are always included in the slices to preserve the output scores
is_gen_added = False
new_target_spans = []
if target_spans is not None:
for span in target_spans:
if span[1] > attr.attr_pos_start and is_gen_added:
continue
elif span[1] > attr.attr_pos_start and not is_gen_added:
new_target_spans.append((span[0], attr.attr_pos_end))
is_gen_added = True
else:
new_target_spans.append(span)
if not is_gen_added:
new_target_spans.append((attr.attr_pos_start, attr.attr_pos_end))
return super().aggregate(attr, source_spans=source_spans, target_spans=new_target_spans, **kwargs)

@staticmethod
def aggregate_source(attr: "FeatureAttributionSequenceOutput", source_spans: list[tuple[int, int]], **kwargs):
sliced_source = []
for span in source_spans:
sliced_source.extend(attr.source[span[0] : span[1]])
return sliced_source

@staticmethod
def aggregate_target(attr: "FeatureAttributionSequenceOutput", target_spans: list[tuple[int, int]], **kwargs):
sliced_target = []
for span in target_spans:
sliced_target.extend(attr.target[span[0] : span[1]])
return sliced_target

@staticmethod
def aggregate_source_attributions(attr: "FeatureAttributionSequenceOutput", source_spans, **kwargs):
if attr.source_attributions is None:
return attr.source_attributions
return torch.cat(
tuple(attr.source_attributions[span[0] : span[1], ...] for span in source_spans),
dim=0,
)

@staticmethod
def aggregate_target_attributions(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs):
if attr.target_attributions is None:
return attr.target_attributions
return torch.cat(
tuple(attr.target_attributions[span[0] : span[1], ...] for span in target_spans),
dim=0,
)

@staticmethod
def aggregate_step_scores(attr: "FeatureAttributionSequenceOutput", **kwargs):
return attr.step_scores

@classmethod
def aggregate_sequence_scores(
cls,
attr: "FeatureAttributionSequenceOutput",
source_spans,
target_spans,
**kwargs,
):
if not attr.sequence_scores:
return attr.sequence_scores
out_dict = {}
for name, step_scores in attr.sequence_scores.items():
if name.startswith("decoder"):
out_dict[name] = torch.cat(
tuple(step_scores[span[0] : span[1], ...] for span in target_spans),
dim=0,
)
elif name.startswith("encoder"):
out_dict[name] = torch.cat(
tuple(step_scores[span[0] : span[1], span[0] : span[1], ...] for span in source_spans),
dim=0,
)
else:
out_dict[name] = torch.cat(
tuple(step_scores[span[0] : span[1], ...] for span in source_spans),
dim=0,
)
return out_dict

@staticmethod
def aggregate_attr_pos_start(attr: "FeatureAttributionSequenceOutput", target_spans, **kwargs):
if not target_spans:
return attr.attr_pos_start
tot_sliced_len = sum(min(s[1], attr.attr_pos_start) - s[0] for s in target_spans)
return tot_sliced_len

@staticmethod
def aggregate_attr_pos_end(attr: "FeatureAttributionSequenceOutput", **kwargs):
return attr.attr_pos_end
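
The generated-token preservation performed in `SliceAggregator.aggregate` can be sketched as a standalone helper (the name `normalize_target_spans` is hypothetical; the loop mirrors the logic in the diff above, where generated positions are always re-appended so output scores survive slicing):

```python
def normalize_target_spans(spans, attr_pos_start, attr_pos_end):
    """Ensure the generated region [attr_pos_start, attr_pos_end) survives slicing."""
    new_spans, is_gen_added = [], False
    for start, end in spans:
        if end > attr_pos_start:
            if not is_gen_added:
                # Extend the first span crossing into generation to cover all of it
                new_spans.append((start, attr_pos_end))
                is_gen_added = True
            # Further spans inside the generated region are redundant and dropped
        else:
            new_spans.append((start, end))
    if not is_gen_added:
        new_spans.append((attr_pos_start, attr_pos_end))
    return new_spans

# Slicing prompt tokens 13..73 when generation spans positions 80..100:
# the generated span (80, 100) is appended automatically.
print(normalize_target_spans([(13, 73)], 80, 100))  # [(13, 73), (80, 100)]
```

This is why `out[13:73]` in the changelog example still shows attribution scores for the generated summary tokens.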
10 changes: 10 additions & 0 deletions inseq/data/attribution.py
@@ -165,6 +165,16 @@ def __post_init__(self):
if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
self.attr_pos_end = len(self.target)

def __getitem__(self, s: Union[slice, int]) -> "FeatureAttributionSequenceOutput":
source_spans = None if self.source_attributions is None else (s.start, s.stop)
target_spans = None if self.source_attributions is not None else (s.start, s.stop)
return self.aggregate("slices", source_spans=source_spans, target_spans=target_spans)

def __sub__(self, other: "FeatureAttributionSequenceOutput") -> "FeatureAttributionSequenceOutput":
if not isinstance(other, self.__class__):
raise ValueError(f"Cannot compare {type(other)} with {type(self)}")
return self.aggregate("pair", paired_attr=other, do_post_aggregation_checks=False)
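
The two dunder shortcuts above simply dispatch to `aggregate`, which can be illustrated with a minimal stand-in (a hypothetical `Attrib` toy class, not the inseq implementation):

```python
class Attrib:
    """Toy attribution container mimicking the dunder-to-aggregate dispatch."""

    def __init__(self, scores):
        self.scores = scores

    def aggregate(self, name, **kwargs):
        if name == "slices":
            start, stop = kwargs["target_spans"]
            return Attrib(self.scores[start:stop])
        if name == "pair":
            other = kwargs["paired_attr"]
            return Attrib([a - b for a, b in zip(self.scores, other.scores)])
        raise ValueError(f"Unknown aggregator: {name}")

    def __getitem__(self, s):
        # out[start:stop] -> slice-style aggregation
        return self.aggregate("slices", target_spans=(s.start, s.stop))

    def __sub__(self, other):
        # out_a - out_b -> pair-style aggregation
        if not isinstance(other, self.__class__):
            raise ValueError(f"Cannot compare {type(other)} with {type(self)}")
        return self.aggregate("pair", paired_attr=other)

a = Attrib([1.0, 2.0, 3.0])
b = Attrib([0.5, 1.0, 1.5])
print((a - b).scores)  # [0.5, 1.0, 1.5]
print(a[0:2].scores)   # [1.0, 2.0]
```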

@staticmethod
def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
if attr.source_attributions is None or name.startswith("decoder"):
23 changes: 15 additions & 8 deletions inseq/data/viz.py
@@ -209,12 +209,16 @@ def get_saliency_heatmap_html(
uuid = "".join(random.choices(string.ascii_lowercase, k=20))
out = saliency_heatmap_table_header
# add top row containing target tokens
out += "<tr><th></th><th></th>"
for column_idx in range(len(column_labels)):
out += f"<th>{column_idx}</th>"
out += "</tr><tr><th></th><th></th>"
for column_label in column_labels:
out += f"<th>{sanitize_html(column_label)}</th>"
out += "</tr>"
if scores is not None:
for row_index in range(scores.shape[0]):
out += f"<tr><th>{sanitize_html(row_labels[row_index])}</th>"
out += f"<tr><th>{row_index}</th><th>{sanitize_html(row_labels[row_index])}</th>"
for col_index in range(scores.shape[1]):
score = ""
if not np.isnan(scores[row_index, col_index]):
@@ -223,7 +227,7 @@
out += "</tr>"
if step_scores is not None:
for step_score_name, step_score_values in step_scores.items():
out += f'<tr style="outline: thin solid"><th><b>{step_score_name}</b></th>'
out += f'<tr style="outline: thin solid"><th></th><th><b>{step_score_name}</b></th>'
if isinstance(step_scores_threshold, float):
threshold = step_scores_threshold
else:
@@ -254,20 +258,23 @@ def get_saliency_heatmap_rich(
label: str = "",
step_scores_threshold: Union[float, dict[str, float]] = 0.5,
):
columns = [Column(header="", justify="right", overflow="fold")]
for column_label in column_labels:
columns.append(Column(header=escape(column_label), justify="center", overflow="fold"))
columns = [
Column(header="", justify="right", overflow="fold"),
Column(header="", justify="right", overflow="fold"),
]
for idx, column_label in enumerate(column_labels):
columns.append(Column(header=f"{idx}\n{escape(column_label)}", justify="center", overflow="fold"))
table = Table(
*columns,
title=f"{label + ' ' if label else ''}Saliency Heatmap",
caption="x: Generated tokens, y: Attributed tokens",
caption="→ : Generated tokens, ↓ : Attributed tokens",
padding=(0, 1, 0, 1),
show_lines=False,
box=box.HEAVY_HEAD,
)
if scores is not None:
for row_index in range(scores.shape[0]):
row = [Text(escape(row_labels[row_index]), style="bold")]
row = [Text(f"{row_index}", style="bold"), Text(escape(row_labels[row_index]), style="bold")]
for col_index in range(scores.shape[1]):
color = Color.from_rgb(*input_colors[row_index][col_index])
score = ""
@@ -282,7 +289,7 @@
else:
threshold = step_scores_threshold.get(step_score_name, 0.5)
style = lambda val, limit: "bold" if abs(val) >= limit and isinstance(val, float) else ""
score_row = [Text(escape(step_score_name), style="bold")]
score_row = [Text(""), Text(escape(step_score_name), style="bold")]
for score in step_score_values:
curr_score = round(score.item(), 2) if isinstance(score, float) else score.item()
score_row.append(Text(f"{score:.2f}", justify="center", style=style(curr_score, threshold)))
3 changes: 1 addition & 2 deletions inseq/utils/viz_utils.py
@@ -109,15 +109,14 @@ def get_colors(
saliency_heatmap_table_header = """
<table border="1" cellpadding="5" cellspacing="5"
style="overflow-x:scroll;display:block;">
<tr><th></th>
"""

saliency_heatmap_html = """
<div id="{uuid}_saliency_plot" class="{uuid}_viz_content">
<div style="margin:5px;font-family:sans-serif;font-weight:bold;">
<span style="font-size: 20px;">{label} Saliency Heatmap</span>
<br>
x: Generated tokens, y: Attributed tokens
→ : Generated tokens, ↓ : Attributed tokens
</div>
{content}
</div>