
Commit edfd5e3

Update tutorial to contrastive attribution changes (#231)
* Remove weight_attribution from examples

* Fix tutorial
gsarti authored Nov 1, 2023
1 parent 33f6932 commit edfd5e3
Showing 9 changed files with 417 additions and 375 deletions.
2 changes: 0 additions & 2 deletions docs/source/examples/quickstart.rst
@@ -248,8 +248,6 @@ Inseq allows users to specify custom attribution targets using the ``attributed_
step_scores=["contrast_prob_diff"]
)
-# Weight attribution scores by the difference in probabilities
-out.weight_attributions("contrast_prob_diff")
out.show()
.. raw:: html
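For context, after this change the quickstart's contrastive example computes and displays the step score directly, with no separate weighting call. A minimal sketch of the resulting usage (the model and input texts here are illustrative assumptions, following the pattern used elsewhere in the Inseq docs):

import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "input_x_gradient")
out = model.attribute(
    "The manager went home.",
    attributed_fn="contrast_prob_diff",
    contrast_targets="La gérante est rentrée chez elle.",
    step_scores=["contrast_prob_diff"],
)
# No out.weight_attributions(...) step anymore: scores are shown directly
out.show()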
28 changes: 14 additions & 14 deletions docs/source/html_outputs/contrastive_example.htm

Large diffs are not rendered by default.

Binary file modified docs/source/images/cat_example_plot.png
643 changes: 329 additions & 314 deletions examples/inseq_tutorial.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions inseq/attr/step_functions.py
@@ -9,7 +9,7 @@
from ..data import FeatureAttributionInput
from ..data.aggregation_functions import DEFAULT_ATTRIBUTION_AGGREGATE_DICT
from ..utils import extract_signature_args, filter_logits, top_p_logits_mask
-from ..utils.contrast_utils import _get_contrast_output, _setup_contrast_args, contrast_fn_docstring
+from ..utils.contrast_utils import _get_contrast_inputs, _setup_contrast_args, contrast_fn_docstring
from ..utils.typing import EmbeddingsTensor, IdsTensor, SingleScorePerStepTensor, TargetIdsTensor

if TYPE_CHECKING:
@@ -225,14 +225,17 @@ def kl_divergence_fn(
"method used. Use --contrast_force_inputs in the model.attribute call to proceed."
)
original_logits: torch.Tensor = args.attribution_model.output2logits(args.forward_output)
-contrast_output = _get_contrast_output(
+contrast_inputs = _get_contrast_inputs(
args=args,
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=False,
)
-contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output.forward_output)
+c_forward_output = args.attribution_model.get_forward_output(
+    contrast_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
+)
+contrast_logits: torch.Tensor = args.attribution_model.output2logits(c_forward_output)
filtered_original_logits, filtered_contrast_logits = filter_logits(
original_logits=original_logits,
contrast_logits=contrast_logits,
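The change above splits contrastive input preparation (_get_contrast_inputs) from the forward pass, which kl_divergence_fn now runs itself via get_forward_output. From the user side the score is requested as before; a hedged sketch (model and texts are assumptions):

import inseq

model = inseq.load_model("Helsinki-NLP/opus-mt-en-fr", "saliency")
# KL divergence between the original and contrastive output distributions
out = model.attribute(
    "The manager went home.",
    step_scores=["kl_divergence"],
    contrast_sources="The managers went home.",
)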
3 changes: 1 addition & 2 deletions inseq/data/viz.py
@@ -177,8 +177,7 @@ def get_heatmap_type(
)
elif heatmap_type == "Target":
if attribution.target_attributions is not None:
-mask = np.where(attribution.target_attributions.numpy() == 0, float("nan"), 0)
-target_attributions = attribution.target_attributions.numpy() + mask
+target_attributions = attribution.target_attributions.numpy()
else:
target_attributions = None
return heatmap_func(
49 changes: 31 additions & 18 deletions inseq/utils/alignment_utils.py
@@ -302,6 +302,10 @@ def get_adjusted_alignments(
).alignments
alignments = [(a_idx, b_idx) for a_idx, b_idx in alignments if start_pos <= a_idx < end_pos]
is_auto_aligned = True
+logger.warning(
+    f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic "
+    f"sequences, or for languages not covered by the aligner."
+)
else:
raise ValueError(
f"Unknown alignment method: {alignments}. "
@@ -310,11 +314,21 @@
# Sort alignments
alignments = sorted(set(alignments), key=lambda x: (x[0], x[1]))

+# Filter alignments (restrict to one per token)
+filter_aligns = []
+for pair_idx in range(start_pos, end_pos):
+    match_pairs = [(p0, p1) for p0, p1 in alignments if p0 == pair_idx and 0 <= p1 < len(contrast_tokens)]
+    if match_pairs:
+        # If found, use the first match containing an unaligned target token, or the first match otherwise
+        match_pairs_unaligned = [p for p in match_pairs if p[1] not in [f[1] for f in filter_aligns]]
+        valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0]
+        filter_aligns.append(valid_match)

# Filling alignments with missing tokens
if fill_missing:
-filled_alignments = []
+filled_alignments = filter_aligns.copy()
for step_idx, pair_idx in enumerate(reversed(range(start_pos, end_pos)), start=1):
-match_pairs = [pair for pair in alignments if pair[0] == pair_idx and 0 <= pair[1] < len(contrast_tokens)]
+match_pairs = [(p0, p1) for p0, p1 in filter_aligns if p0 == pair_idx and 0 <= p1 < len(contrast_tokens)]

# Default behavior: fill missing alignments with 1:1 position alignments starting from the bottom of the
# two sequences
@@ -323,24 +337,23 @@
filled_alignments.append((pair_idx, len(contrast_tokens) - 1))
else:
filled_alignments.append((pair_idx, len(contrast_tokens) - step_idx))
-else:
-    # If found, use the first match that containing an unaligned target token, first match otherwise
-    match_pairs_unaligned = [p for p in match_pairs if p[1] not in [f[1] for f in filled_alignments]]
-    valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0]
-    filled_alignments.append(valid_match)
-if alignments != filled_alignments:
-    alignments = filled_alignments

+if filter_aligns != filled_alignments:
+    existing_aligns_message = (
+        f"Provided target alignments do not cover all {end_pos - start_pos} tokens from the original sequence."
+    )
+    no_aligns_message = (
+        "No target alignments were provided for the contrastive target. "
+        "Use e.g. 'contrast_targets_alignments=[(0,1), ...]' to provide them in model.attribute"
+    )
logger.warning(
f"Provided alignments do not cover all {end_pos - start_pos} tokens from the original"
" sequence.\nFilling missing position with right-aligned 1:1 position alignments.\n"
f"Generated alignments: {alignments}"
f"{existing_aligns_message if filter_aligns else no_aligns_message}\n"
"Filling missing position with right-aligned 1:1 position alignments."
)
-if is_auto_aligned:
-    logger.warning(
-        f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic "
-        f"sequences, or for languages not covered by the aligner.\nGenerated alignments: {alignments}"
-    )
-return alignments
+filter_aligns = sorted(set(filled_alignments), key=lambda x: (x[0], x[1]))
+if is_auto_aligned or (fill_missing and filter_aligns != filled_alignments):
+    logger.warning(f"Generated alignments: {filter_aligns}")
+return filter_aligns


def get_aligned_idx(a_idx: int, alignments: List[Tuple[int, int]]) -> int:
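To see the new one-alignment-per-token filtering in isolation, here is a self-contained sketch of the same selection logic on toy data (simplified: the start_pos offset, auto-alignment, and fill-missing steps are omitted):

from typing import List, Tuple

def filter_one_per_token(
    alignments: List[Tuple[int, int]], end_pos: int, n_contrast: int
) -> List[Tuple[int, int]]:
    """Keep at most one alignment per original token, preferring contrast tokens not yet aligned."""
    filter_aligns: List[Tuple[int, int]] = []
    for pair_idx in range(end_pos):
        match_pairs = [(p0, p1) for p0, p1 in alignments if p0 == pair_idx and 0 <= p1 < n_contrast]
        if match_pairs:
            unaligned = [p for p in match_pairs if p[1] not in {f[1] for f in filter_aligns}]
            filter_aligns.append(unaligned[0] if unaligned else match_pairs[0])
    return filter_aligns

# Original token 0 matches contrast tokens 0 and 1: only (0, 0) is kept,
# leaving contrast token 1 free to be chosen for original token 1.
print(filter_one_per_token([(0, 0), (0, 1), (1, 1)], end_pos=2, n_contrast=3))  # [(0, 0), (1, 1)]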
54 changes: 34 additions & 20 deletions inseq/utils/contrast_utils.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union

-from transformers.modeling_outputs import ModelOutput
+import torch

from ..data import (
DecoderOnlyBatch,
@@ -52,21 +52,20 @@ def docstring_decorator(fn: "StepFunction") -> "StepFunction":


@dataclass
-class ContrastOutput:
-    forward_output: ModelOutput
+class ContrastInputs:
batch: Union[EncoderDecoderBatch, DecoderOnlyBatch, None] = None
target_ids: Optional[TargetIdsTensor] = None


-def _get_contrast_output(
+def _get_contrast_inputs(
args: "StepFunctionArgs",
contrast_sources: Optional[FeatureAttributionInput] = None,
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
return_contrastive_target_ids: bool = False,
-    **forward_kwargs,
-) -> ContrastOutput:
+    return_contrastive_batch: bool = False,
+) -> ContrastInputs:
"""Utility function to return the output of the model for given contrastive inputs.
Args:
@@ -93,10 +92,13 @@ def _get_contrast_output(
f" ({args.decoder_input_ids.size(0)}). Multi-sentence attribution and methods expanding inputs to"
" multiple steps (e.g. Integrated Gradients) are not yet supported for contrastive attribution."
)

-args.decoder_input_ids = c_batch.target_ids
-args.decoder_input_embeds = c_batch.target_embeds
-args.decoder_attention_mask = c_batch.target_mask
+if (
+    args.decoder_input_ids.shape != c_batch.target_ids.shape
+    or torch.ne(args.decoder_input_ids, c_batch.target_ids).any()
+):
+    args.decoder_input_ids = c_batch.target_ids
+    args.decoder_input_embeds = c_batch.target_embeds
+    args.decoder_attention_mask = c_batch.target_mask
if contrast_sources:
from ..attr.step_functions import StepFunctionEncoderDecoderArgs

@@ -106,12 +108,15 @@ def _get_contrast_output(
"Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models."
)
c_enc_in = args.attribution_model.encode(contrast_sources)
-args.encoder_input_ids = c_enc_in.input_ids
-args.encoder_attention_mask = c_enc_in.attention_mask
-args.encoder_input_embeds = args.attribution_model.embed(args.encoder_input_ids, as_targets=False)
+if (
+    args.encoder_input_ids.shape != c_enc_in.input_ids.shape
+    or torch.ne(args.encoder_input_ids, c_enc_in.input_ids).any()
+):
+    args.encoder_input_ids = c_enc_in.input_ids
+    args.encoder_input_embeds = args.attribution_model.embed(args.encoder_input_ids, as_targets=False)
+    args.encoder_attention_mask = c_enc_in.attention_mask
c_batch = args.attribution_model.formatter.convert_args_to_batch(args)
-return ContrastOutput(
-    forward_output=args.attribution_model.get_forward_output(c_batch, use_embeddings=is_enc_dec, **forward_kwargs),
+return ContrastInputs(
+    batch=c_batch if return_contrastive_batch else None,
target_ids=c_tgt_ids if return_contrastive_target_ids else None,
)
@@ -124,24 +129,21 @@ def _setup_contrast_args(
contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
):
-c_output = _get_contrast_output(
+c_inputs = _get_contrast_inputs(
args,
contrast_sources=contrast_sources,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=True,
+return_contrastive_batch=True,
)
-if c_output.target_ids is not None:
-    args.target_ids = c_output.target_ids
if args.is_attributed_fn:
if contrast_force_inputs:
warnings.warn(
"Forcing contrastive inputs to be used for attribution. This may result in unexpected behavior for "
"gradient-based attribution methods.",
stacklevel=1,
)
-args.forward_output = c_output.forward_output
else:
warnings.warn(
"Contrastive inputs do not match original inputs when using a contrastive attributed function.\n"
Expand All @@ -151,6 +153,18 @@ def _setup_contrast_args(
"set --contrast_force_inputs to True to use the contrastive inputs for attribution instead.",
stacklevel=1,
)
+use_original_output = args.is_attributed_fn and not contrast_force_inputs
+if use_original_output:
+    forward_output = args.forward_output
else:
-args.forward_output = c_output.forward_output
-return args
+    forward_output = args.attribution_model.get_forward_output(
+        c_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
+    )
+c_args = args.attribution_model.formatter.format_step_function_args(
+    args.attribution_model,
+    forward_output=forward_output,
+    target_ids=c_inputs.target_ids if c_inputs.target_ids is not None else args.target_ids,
+    batch=c_inputs.batch,
+    is_attributed_fn=args.is_attributed_fn,
+)
+return c_args
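At the user level, the branching gated by contrast_force_inputs corresponds to the following call pattern (a sketch; the model and inputs are assumptions). By default a contrastive attributed function keeps the original inputs for attribution and only swaps in the contrastive step function arguments; forcing the contrastive inputs emits the warning above for gradient-based methods:

out = model.attribute(
    "The manager went home.",
    attributed_fn="contrast_prob_diff",
    contrast_targets="La gérante est rentrée chez elle.",
    contrast_force_inputs=True,  # attribute the contrastive inputs instead
)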
4 changes: 2 additions & 2 deletions tests/attr/feat/test_step_functions.py
@@ -2,7 +2,7 @@
from pytest import fixture

import inseq
-from inseq.attr.step_functions import StepFunctionArgs, _get_contrast_output, probability_fn
+from inseq.attr.step_functions import StepFunctionArgs, _get_contrast_inputs, probability_fn
from inseq.models import DecoderOnlyAttributionModel, EncoderDecoderAttributionModel


@@ -62,7 +62,7 @@ def attr_prob_diff_fn(
logprob: bool = False,
):
model_probs = probability_fn(args, logprob=logprob)
-c_out = _get_contrast_output(
+c_out = _get_contrast_inputs(
args,
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
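The custom function in this test follows the standard step-function registration pattern; a hedged sketch of how such a function would be exposed to attribute calls (the identifier name is an assumption):

inseq.register_step_function(fn=attr_prob_diff_fn, identifier="attr_prob_diff")
out = model.attribute(
    "The manager went home.",
    attributed_fn="attr_prob_diff",
    contrast_targets="La gérante est rentrée chez elle.",
)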
