From 6feda957369c9711ef19b28b4170212a66a1be34 Mon Sep 17 00:00:00 2001 From: Gabriele Sarti Date: Mon, 30 Oct 2023 11:26:35 +0100 Subject: [PATCH] Attributed behavior for contrastive step functions (#228) * Add contrast utils and is_attributed_fn step function arg * Fix imports --- inseq/attr/feat/feature_attribution.py | 1 + inseq/attr/step_functions.py | 133 +++++---------------- inseq/data/batch.py | 2 + inseq/models/attribution_model.py | 3 +- inseq/models/decoder_only.py | 2 + inseq/models/encoder_decoder.py | 2 + inseq/utils/alignment_utils.py | 9 +- inseq/utils/contrast_utils.py | 156 +++++++++++++++++++++++++ tests/attr/feat/test_step_functions.py | 78 +++++++++++++ 9 files changed, 277 insertions(+), 109 deletions(-) create mode 100644 inseq/utils/contrast_utils.py diff --git a/inseq/attr/feat/feature_attribution.py b/inseq/attr/feat/feature_attribution.py index 6b076828..2994fe13 100644 --- a/inseq/attr/feat/feature_attribution.py +++ b/inseq/attr/feat/feature_attribution.py @@ -586,6 +586,7 @@ def filtered_attribute_step( attribution_model=self.attribution_model, forward_output=output, target_ids=target_ids, + is_attributed_fn=False, batch=batch, ) step_fn_extra_args = get_step_scores_args([step_score], step_scores_args) diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index 67368e52..423092ed 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -6,9 +6,10 @@ import torch.nn.functional as F from transformers.modeling_outputs import ModelOutput -from ..data import DecoderOnlyBatch, FeatureAttributionInput, get_batch_from_inputs, slice_batch_from_position +from ..data import FeatureAttributionInput from ..data.aggregation_functions import DEFAULT_ATTRIBUTION_AGGREGATE_DICT from ..utils import extract_signature_args, filter_logits, top_p_logits_mask +from ..utils.contrast_utils import _get_contrast_output, _setup_contrast_args, contrast_fn_docstring from ..utils.typing import EmbeddingsTensor, IdsTensor, SingleScorePerStepTensor, TargetIdsTensor if TYPE_CHECKING: @@ -27,6 +28,9 @@ class StepFunctionBaseArgs: forward_output (:class:`~inseq.models.ModelOutput`): The output of the model's forward pass. target_ids (:obj:`torch.Tensor`): Tensor of target token ids of size :obj:`(batch_size,)` corresponding to the target predicted tokens for the next generation step. + is_attributed_fn (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether the step function is being used + as attribution target. Defaults to :obj:`False`. Enables custom behavior that is different whether the fn + is used as target or not. encoder_input_ids (:obj:`torch.Tensor`): Tensor of ids of encoder input tokens of size :obj:`(batch_size, source_seq_len)`, representing encoder inputs at the present step. Available only for encoder-decoder models. @@ -50,6 +54,7 @@ class StepFunctionBaseArgs: decoder_input_ids: IdsTensor decoder_input_embeds: EmbeddingsTensor decoder_attention_mask: IdsTensor + is_attributed_fn: bool @dataclass @@ -76,36 +81,6 @@ def __call__( ... -CONTRAST_FN_ARGS_DOCSTRING = """Args: - contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute - the contrastive step function for encoder-decoder models. If not specified, the source text is assumed to - match the original source text. Defaults to :obj:`None`. - contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original - target text. If not specified, the original target text is used as contrastive target (will result in same - output unless ``contrast_sources`` are specified). Defaults to :obj:`None`. - contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the - first element is the index of the original target token and the second element is the index of the - contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is - not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all - available tokens. Defaults to :obj:`None`. -""" - - -def contrast_fn_docstring(): - def docstring_decorator(fn: StepFunction): - """Returns the docstring for the contrastive step functions.""" - if fn.__doc__ is not None: - if "Args:\n" in fn.__doc__: - fn.__doc__ = fn.__doc__.replace("Args:\n", CONTRAST_FN_ARGS_DOCSTRING) - else: - fn.__doc__ = fn.__doc__ + "\n " + CONTRAST_FN_ARGS_DOCSTRING - else: - fn.__doc__ = CONTRAST_FN_ARGS_DOCSTRING - return fn - - return docstring_decorator - - def logit_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: """Compute the logit of the target_ids from the model's output logits.""" logits = args.attribution_model.output2logits(args.forward_output) @@ -149,87 +124,27 @@ def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: return 2 ** crossentropy_fn(args) -@contrast_fn_docstring() -def _get_contrast_output( - args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, - return_contrastive_target_ids: bool = False, - **forward_kwargs, -) -> ModelOutput: - """Utility function to return the output of the model for given contrastive inputs. - - Args: - return_contrastive_target_ids (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to return the - contrastive target ids as well as the model output. Defaults to :obj:`False`. - **forward_kwargs: Additional keyword arguments to be passed to the model's forward pass. - """ - c_tgt_ids = None - is_enc_dec = args.attribution_model.is_encoder_decoder - if contrast_targets: - c_batch = DecoderOnlyBatch.from_batch( - get_batch_from_inputs( - attribution_model=args.attribution_model, - inputs=contrast_targets, - as_targets=is_enc_dec, - ) - ) - curr_prefix_len = args.decoder_input_ids.size(1) - if len(contrast_targets_alignments) > 0 and isinstance(contrast_targets_alignments[0], list): - contrast_targets_alignments = contrast_targets_alignments[0] - c_batch, c_tgt_ids = slice_batch_from_position(c_batch, curr_prefix_len, contrast_targets_alignments) - - if args.decoder_input_ids.size(0) != c_batch.target_ids.size(0): - raise ValueError( - f"Contrastive batch size ({c_batch.target_ids.size(0)}) must match candidate batch size" - f" ({args.decoder_input_ids.size(0)}). Multi-sentence attribution and methods expanding inputs to" - " multiple steps (e.g. Integrated Gradients) are not yet supported for contrastive attribution." - ) - - args.decoder_input_ids = c_batch.target_ids - args.decoder_input_embeds = c_batch.target_embeds - args.decoder_attention_mask = c_batch.target_mask - if contrast_sources: - if not (is_enc_dec and isinstance(args, StepFunctionEncoderDecoderArgs)): - raise ValueError( - "Contrastive source inputs can only be used with encoder-decoder models. " - "Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models." - ) - c_enc_in = args.attribution_model.encode(contrast_sources) - args.encoder_input_ids = c_enc_in.input_ids - args.encoder_attention_mask = c_enc_in.attention_mask - args.encoder_input_embeds = args.attribution_model.embed(args.encoder_input_ids, as_targets=False) - c_batch = args.attribution_model.formatter.convert_args_to_batch(args) - c_out = args.attribution_model.get_forward_output(c_batch, use_embeddings=is_enc_dec, **forward_kwargs) - if return_contrastive_target_ids: - return c_out, c_tgt_ids - return c_out - - @contrast_fn_docstring() def contrast_logits_fn( args: StepFunctionArgs, contrast_sources: Optional[FeatureAttributionInput] = None, contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, + contrast_force_inputs: bool = False, ): """Returns the logit of a generation target given contrastive context or target prediction alternative. If only ``contrast_targets`` are specified, the logit of the contrastive prediction is computed given same context. The logit for the same token given contrastive source/target preceding context can also be computed using ``contrast_sources`` without specifying ``contrast_targets``. """ - c_output, c_tgt_ids = _get_contrast_output( + c_args = _setup_contrast_args( args, contrast_sources=contrast_sources, contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, - return_contrastive_target_ids=True, + contrast_force_inputs=contrast_force_inputs, ) - if c_tgt_ids: - args.target_ids = c_tgt_ids - args.forward_output = c_output - return logit_fn(args) + return logit_fn(c_args) @contrast_fn_docstring() @@ -239,23 +154,21 @@ def contrast_prob_fn( contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, logprob: bool = False, + contrast_force_inputs: bool = False, ): """Returns the probability of a generation target given contrastive context or target prediction alternative. If only ``contrast_targets`` are specified, the probability of the contrastive prediction is computed given same context. The probability for the same token given contrastive source/target preceding context can also be computed using ``contrast_sources`` without specifying ``contrast_targets``. """ - c_output, c_tgt_ids = _get_contrast_output( + c_args = _setup_contrast_args( args, contrast_sources=contrast_sources, contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, - return_contrastive_target_ids=True, + contrast_force_inputs=contrast_force_inputs, ) - if c_tgt_ids: - args.target_ids = c_tgt_ids - args.forward_output = c_output - return probability_fn(args, logprob=logprob) + return probability_fn(c_args, logprob=logprob) @contrast_fn_docstring() @@ -264,7 +177,7 @@ def pcxmi_fn( contrast_sources: Optional[FeatureAttributionInput] = None, contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, - **kwargs, + contrast_force_inputs: bool = False, ) -> SingleScorePerStepTensor: """Compute the pointwise conditional cross-mutual information (P-CXMI) of target ids given original and contrastive input options. The P-CXMI is defined as the negative log-ratio between the conditional probability of the target @@ -277,6 +190,7 @@ def pcxmi_fn( contrast_sources=contrast_sources, contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, + contrast_force_inputs=contrast_force_inputs, ) return -torch.log2(torch.div(original_probs, contrast_probs)) @@ -290,6 +204,7 @@ def kl_divergence_fn( top_k: int = 0, top_p: float = 1.0, min_tokens_to_keep: int = 1, + contrast_force_inputs: bool = False, ) -> SingleScorePerStepTensor: """Compute the pointwise Kullback-Leibler divergence of target ids given original and contrastive input options. The KL divergence is the expectation of the log difference between the probabilities of regular (P) and contrastive @@ -304,7 +219,11 @@ def kl_divergence_fn( min_tokens_to_keep (:obj:`int`): Minimum number of tokens to keep with :obj:`top_p` filtering. Defaults to :obj:`1`. """ - + if not contrast_force_inputs and args.is_attributed_fn: + raise RuntimeError( + "Using KL divergence as attribution target might lead to unexpected results, depending on the attribution" + "method used. Use --contrast_force_inputs in the model.attribute call to proceed." + ) original_logits: torch.Tensor = args.attribution_model.output2logits(args.forward_output) contrast_output = _get_contrast_output( args=args, @@ -313,7 +232,7 @@ def kl_divergence_fn( contrast_targets_alignments=contrast_targets_alignments, return_contrastive_target_ids=False, ) - contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output) + contrast_logits: torch.Tensor = args.attribution_model.output2logits(contrast_output.forward_output) filtered_original_logits, filtered_contrast_logits = filter_logits( original_logits=original_logits, contrast_logits=contrast_logits, @@ -338,6 +257,7 @@ def contrast_prob_diff_fn( contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, logprob: bool = False, + contrast_force_inputs: bool = False, ): """Returns the difference between next step probability for a candidate generation target vs. a contrastive alternative. Can be used as attribution target to answer the question: "Which features were salient in the @@ -353,6 +273,7 @@ def contrast_prob_diff_fn( contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, logprob=logprob, + contrast_force_inputs=contrast_force_inputs, ) return model_probs - contrast_probs @@ -363,6 +284,7 @@ def contrast_logits_diff_fn( contrast_sources: Optional[FeatureAttributionInput] = None, contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, + contrast_force_inputs: bool = False, ): """Equivalent to ``contrast_prob_diff_fn`` but for logits. The original target function used in `Yin and Neubig (2022) `__ @@ -373,6 +295,7 @@ def contrast_logits_diff_fn( contrast_sources=contrast_sources, contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, + contrast_force_inputs=contrast_force_inputs, ) return model_logits - contrast_logits @@ -383,6 +306,7 @@ def in_context_pvi_fn( contrast_sources: Optional[FeatureAttributionInput] = None, contrast_targets: Optional[FeatureAttributionInput] = None, contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, + contrast_force_inputs: bool = False, ): """Returns the in-context pointwise V-usable information as defined by `Lu et al. (2023) `__. In-context PVI is a variant of P-CXMI that captures the amount of usable @@ -400,6 +324,7 @@ def in_context_pvi_fn( contrast_targets=contrast_targets, contrast_targets_alignments=contrast_targets_alignments, logprob=True, + contrast_force_inputs=contrast_force_inputs, ) return -orig_logprob + contrast_logprob diff --git a/inseq/data/batch.py b/inseq/data/batch.py index dde9a0f5..c72b95c8 100644 --- a/inseq/data/batch.py +++ b/inseq/data/batch.py @@ -235,6 +235,8 @@ def from_batch(self, batch: Batch) -> "DecoderOnlyBatch": def slice_batch_from_position( batch: DecoderOnlyBatch, curr_idx: int, alignments: Optional[List[Tuple[int, int]]] = None ) -> Tuple[DecoderOnlyBatch, IdsTensor]: + if len(alignments) > 0 and isinstance(alignments[0], list): + alignments = alignments[0] truncate_idx = get_aligned_idx(curr_idx, alignments) tgt_ids = batch.target_ids[:, truncate_idx] return batch[:truncate_idx], tgt_ids diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py index a7c20a82..e95e168b 100644 --- a/inseq/models/attribution_model.py +++ b/inseq/models/attribution_model.py @@ -133,6 +133,7 @@ def format_step_function_args( forward_output: ModelOutput, target_ids: ExpandedTargetIdsTensor, batch: DecoderOnlyBatch, + is_attributed_fn: bool = False, ) -> StepFunctionArgs: raise NotImplementedError() @@ -650,7 +651,7 @@ def _forward( output = self.get_forward_output(batch, use_embeddings=use_embeddings, **kwargs) logger.debug(f"logits: {pretty_tensor(output.logits)}") step_fn_args = self.formatter.format_step_function_args( - attribution_model=self, forward_output=output, target_ids=target_ids, batch=batch + attribution_model=self, forward_output=output, target_ids=target_ids, is_attributed_fn=True, batch=batch ) step_fn_extra_args = {k: v for k, v in zip(attributed_fn_argnames, args) if v is not None} return attributed_fn(step_fn_args, **step_fn_extra_args) diff --git a/inseq/models/decoder_only.py b/inseq/models/decoder_only.py index 2c633be4..b5586ef3 100644 --- a/inseq/models/decoder_only.py +++ b/inseq/models/decoder_only.py @@ -136,11 +136,13 @@ def format_step_function_args( forward_output: ModelOutput, target_ids: ExpandedTargetIdsTensor, batch: DecoderOnlyBatch, + is_attributed_fn: bool = False, ) -> StepFunctionDecoderOnlyArgs: return StepFunctionDecoderOnlyArgs( attribution_model=attribution_model, forward_output=forward_output, target_ids=target_ids, + is_attributed_fn=is_attributed_fn, decoder_input_ids=batch.target_ids, decoder_attention_mask=batch.target_mask, decoder_input_embeds=batch.target_embeds, diff --git a/inseq/models/encoder_decoder.py b/inseq/models/encoder_decoder.py index bbb5dbad..c3058fc8 100644 --- a/inseq/models/encoder_decoder.py +++ b/inseq/models/encoder_decoder.py @@ -177,11 +177,13 @@ def format_step_function_args( forward_output: ModelOutput, target_ids: ExpandedTargetIdsTensor, batch: EncoderDecoderBatch, + is_attributed_fn: bool = False, ) -> StepFunctionEncoderDecoderArgs: return StepFunctionEncoderDecoderArgs( attribution_model=attribution_model, forward_output=forward_output, target_ids=target_ids, + is_attributed_fn=is_attributed_fn, encoder_input_ids=batch.source_ids, decoder_input_ids=batch.target_ids, encoder_input_embeds=batch.source_embeds, diff --git a/inseq/utils/alignment_utils.py b/inseq/utils/alignment_utils.py index 14271677..7a3ac877 100644 --- a/inseq/utils/alignment_utils.py +++ b/inseq/utils/alignment_utils.py @@ -319,8 +319,8 @@ def get_adjusted_alignments( # Default behavior: fill missing alignments with 1:1 position alignments starting from the bottom of the # two sequences if not match_pairs: - if len(contrast_tokens) < step_idx: - filled_alignments.append((pair_idx, 0)) + if (len(contrast_tokens) - step_idx) < start_pos: + filled_alignments.append((pair_idx, len(contrast_tokens) - 1)) else: filled_alignments.append((pair_idx, len(contrast_tokens) - step_idx)) else: @@ -329,11 +329,12 @@ def get_adjusted_alignments( valid_match = match_pairs_unaligned[0] if match_pairs_unaligned else match_pairs[0] filled_alignments.append(valid_match) if alignments != filled_alignments: + alignments = filled_alignments logger.warning( f"Provided alignments do not cover all {end_pos - start_pos} tokens from the original" - " sequence.\nFilling missing position with right-aligned 1:1 position alignments." + " sequence.\nFilling missing position with right-aligned 1:1 position alignments.\n" + f"Generated alignments: {alignments}" ) - alignments = filled_alignments if is_auto_aligned: logger.warning( f"Using {ALIGN_MODEL_ID} for automatic alignments. Provide custom alignments for non-linguistic " diff --git a/inseq/utils/contrast_utils.py b/inseq/utils/contrast_utils.py new file mode 100644 index 00000000..ca342181 --- /dev/null +++ b/inseq/utils/contrast_utils.py @@ -0,0 +1,156 @@ +import logging +import warnings +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union + +from transformers.modeling_outputs import ModelOutput + +from ..data import ( + DecoderOnlyBatch, + EncoderDecoderBatch, + FeatureAttributionInput, + get_batch_from_inputs, + slice_batch_from_position, +) +from ..utils.typing import TargetIdsTensor + +if TYPE_CHECKING: + from ..attr.step_functions import StepFunction, StepFunctionArgs + +logger = logging.getLogger(__name__) + +CONTRAST_FN_ARGS_DOCSTRING = """Args: + contrast_sources (:obj:`str` or :obj:`list(str)`): Source text(s) used as contrastive inputs to compute + the contrastive step function for encoder-decoder models. If not specified, the source text is assumed to + match the original source text. Defaults to :obj:`None`. + contrast_targets (:obj:`str` or :obj:`list(str)`): Contrastive target text(s) to be compared to the original + target text. If not specified, the original target text is used as contrastive target (will result in same + output unless ``contrast_sources`` are specified). Defaults to :obj:`None`. + contrast_targets_alignments (:obj:`list(tuple(int, int))`, `optional`): A list of tuples of indices, where the + first element is the index of the original target token and the second element is the index of the + contrastive target token, used only if :obj:`contrast_targets` is specified. If an explicit alignment is + not specified, the alignment of the original and contrastive target texts is assumed to be 1:1 for all + available tokens. Defaults to :obj:`None`. + contrast_force_inputs (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to force the contrastive + inputs to be used for attribution. Defaults to :obj:`False`. +""" + + +def contrast_fn_docstring() -> Callable[..., "StepFunction"]: + def docstring_decorator(fn: "StepFunction") -> "StepFunction": + """Returns the docstring for the contrastive step functions.""" + if fn.__doc__ is not None: + if "Args:\n" in fn.__doc__: + fn.__doc__ = fn.__doc__.replace("Args:\n", CONTRAST_FN_ARGS_DOCSTRING) + else: + fn.__doc__ = fn.__doc__ + "\n " + CONTRAST_FN_ARGS_DOCSTRING + else: + fn.__doc__ = CONTRAST_FN_ARGS_DOCSTRING + return fn + + return docstring_decorator + + +@dataclass +class ContrastOutput: + forward_output: ModelOutput + batch: Union[EncoderDecoderBatch, DecoderOnlyBatch, None] = None + target_ids: Optional[TargetIdsTensor] = None + + +def _get_contrast_output( + args: "StepFunctionArgs", + contrast_sources: Optional[FeatureAttributionInput] = None, + contrast_targets: Optional[FeatureAttributionInput] = None, + contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, + return_contrastive_target_ids: bool = False, + return_contrastive_batch: bool = False, + **forward_kwargs, +) -> ContrastOutput: + """Utility function to return the output of the model for given contrastive inputs. + + Args: + return_contrastive_target_ids (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to return the + contrastive target ids as well as the model output. Defaults to :obj:`False`. + **forward_kwargs: Additional keyword arguments to be passed to the model's forward pass. + """ + c_tgt_ids = None + is_enc_dec = args.attribution_model.is_encoder_decoder + if contrast_targets: + c_batch = DecoderOnlyBatch.from_batch( + get_batch_from_inputs( + attribution_model=args.attribution_model, + inputs=contrast_targets, + as_targets=is_enc_dec, + ) + ) + curr_prefix_len = args.decoder_input_ids.size(1) + c_batch, c_tgt_ids = slice_batch_from_position(c_batch, curr_prefix_len, contrast_targets_alignments) + + if args.decoder_input_ids.size(0) != c_batch.target_ids.size(0): + raise ValueError( + f"Contrastive batch size ({c_batch.target_ids.size(0)}) must match candidate batch size" + f" ({args.decoder_input_ids.size(0)}). Multi-sentence attribution and methods expanding inputs to" + " multiple steps (e.g. Integrated Gradients) are not yet supported for contrastive attribution." + ) + + args.decoder_input_ids = c_batch.target_ids + args.decoder_input_embeds = c_batch.target_embeds + args.decoder_attention_mask = c_batch.target_mask + if contrast_sources: + from ..attr.step_functions import StepFunctionEncoderDecoderArgs + + if not (is_enc_dec and isinstance(args, StepFunctionEncoderDecoderArgs)): + raise ValueError( + "Contrastive source inputs can only be used with encoder-decoder models. " + "Use `contrast_targets` to set a contrastive target containing a prefix for decoder-only models." + ) + c_enc_in = args.attribution_model.encode(contrast_sources) + args.encoder_input_ids = c_enc_in.input_ids + args.encoder_attention_mask = c_enc_in.attention_mask + args.encoder_input_embeds = args.attribution_model.embed(args.encoder_input_ids, as_targets=False) + c_batch = args.attribution_model.formatter.convert_args_to_batch(args) + return ContrastOutput( + forward_output=args.attribution_model.get_forward_output(c_batch, use_embeddings=is_enc_dec, **forward_kwargs), + batch=c_batch if return_contrastive_batch else None, + target_ids=c_tgt_ids if return_contrastive_target_ids else None, + ) + + +def _setup_contrast_args( + args: "StepFunctionArgs", + contrast_sources: Optional[FeatureAttributionInput] = None, + contrast_targets: Optional[FeatureAttributionInput] = None, + contrast_targets_alignments: Optional[List[List[Tuple[int, int]]]] = None, + contrast_force_inputs: bool = False, +): + c_output = _get_contrast_output( + args, + contrast_sources=contrast_sources, + contrast_targets=contrast_targets, + contrast_targets_alignments=contrast_targets_alignments, + return_contrastive_target_ids=True, + return_contrastive_batch=True, + ) + if c_output.target_ids is not None: + args.target_ids = c_output.target_ids + if args.is_attributed_fn: + if contrast_force_inputs: + warnings.warn( + "Forcing contrastive inputs to be used for attribution. This may result in unexpected behavior for " + "gradient-based attribution methods.", + stacklevel=1, + ) + args.forward_output = c_output.forward_output + else: + warnings.warn( + "Contrastive inputs do not match original inputs when using a contrastive attributed function.\n" + "By default we force the original inputs to be used (i.e. only the contrastive predicted target is " + "different).\nThis is a requirement for gradient-based attribution method, as contrastive inputs don't" + " participate in gradient computation.\nFor attribution methods with less stringent requirements, " + "set --contrast_force_inputs to True to use the contrastive inputs for attribution instead.", + stacklevel=1, + ) + else: + args.forward_output = c_output.forward_output + return args diff --git a/tests/attr/feat/test_step_functions.py b/tests/attr/feat/test_step_functions.py index 89400ea5..b4a8fcec 100644 --- a/tests/attr/feat/test_step_functions.py +++ b/tests/attr/feat/test_step_functions.py @@ -1,6 +1,8 @@ +import torch from pytest import fixture import inseq +from inseq.attr.step_functions import StepFunctionArgs, _get_contrast_output, probability_fn from inseq.models import DecoderOnlyAttributionModel, EncoderDecoderAttributionModel @@ -51,3 +53,79 @@ def test_contrast_prob_consistency_enc_dec(saliency_mt_model: EncoderDecoderAttr ) regular_prob = out_regular.sequence_attributions[0].step_scores["probability"] assert all(c == r for c, r in zip(contrast_prob, regular_prob[-len(contrast_prob) :])) + + +def attr_prob_diff_fn( + args: StepFunctionArgs, + contrast_targets, + contrast_targets_alignments=None, + logprob: bool = False, +): + model_probs = probability_fn(args, logprob=logprob) + c_out = _get_contrast_output( + args, + contrast_targets=contrast_targets, + contrast_targets_alignments=contrast_targets_alignments, + return_contrastive_target_ids=True, + ) + args.target_ids = c_out.target_ids + contrast_probs = probability_fn(args, logprob=logprob) + return model_probs - contrast_probs + + +def test_contrast_attribute_target_only_enc_dec(saliency_mt_model: EncoderDecoderAttributionModel): + inseq.register_step_function(fn=attr_prob_diff_fn, identifier="attr_prob_diff", overwrite=True) + src = "The nurse was tired and went home." + tgt = "L'infermiere era stanco e andò a casa." + contrast_tgt = "L'infermiera era stanca e andò a casa." + out_explicit_logit_prob_diff = saliency_mt_model.attribute( + src, + tgt, + contrast_targets=contrast_tgt, + attributed_fn="attr_prob_diff", + step_scores=["attr_prob_diff", "contrast_prob_diff"], + attribute_target=True, + ) + out_default_prob_diff = saliency_mt_model.attribute( + src, + tgt, + contrast_targets=contrast_tgt, + attributed_fn="contrast_prob_diff", + step_scores=["contrast_prob_diff"], + attribute_target=True, + ) + assert torch.allclose( + out_explicit_logit_prob_diff[0].step_scores["contrast_prob_diff"], + out_default_prob_diff[0].step_scores["contrast_prob_diff"], + ) + assert torch.allclose( + out_explicit_logit_prob_diff[0].source_attributions, + out_default_prob_diff[0].source_attributions, + ) + assert torch.allclose( + out_explicit_logit_prob_diff[0].target_attributions, + out_default_prob_diff[0].target_attributions, + equal_nan=True, + ) + out_contrast_force_inputs_prob_diff = saliency_mt_model.attribute( + src, + tgt, + contrast_targets=contrast_tgt, + attributed_fn="contrast_prob_diff", + step_scores=["contrast_prob_diff"], + attribute_target=True, + contrast_force_inputs=True, + ) + assert not torch.allclose( + out_explicit_logit_prob_diff[0].source_attributions, + out_contrast_force_inputs_prob_diff[0].source_attributions, + ) + assert not torch.allclose( + out_explicit_logit_prob_diff[0].target_attributions, + out_contrast_force_inputs_prob_diff[0].target_attributions, + equal_nan=True, + ) + assert torch.allclose( + out_explicit_logit_prob_diff[0].step_scores["contrast_prob_diff"], + out_default_prob_diff[0].step_scores["contrast_prob_diff"], + )