diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py
index fde51c18..d478fd09 100644
--- a/inseq/data/aggregator.py
+++ b/inseq/data/aggregator.py
@@ -691,10 +691,10 @@ class SubwordAggregator(ContiguousSpanAggregator):
             preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
         aggregate_source (bool, optional): Whether to aggregate over the source sequence. Defaults to True.
         aggregate_target (bool, optional): Whether to aggregate over the target sequence. Defaults to True.
-        special_symbol (str, optional): Symbol used to identify subwords. Defaults to '▁', used by SentencePiece.
-            If is_suffix_symbol=True, then this symbol is used to identify parts to be aggregated (e.g. # in WordPiece,
-            ['phen', '##omen', '##al']). Otherwise, it identifies the roots that should be preserved (e.g. ▁ in
-            SentencePiece, ['▁phen', 'omen', 'al']).
+        special_chars (str or tuple of str, optional): One or more characters used to identify subword boundaries.
+            Defaults to '▁', used by SentencePiece. If is_suffix_symbol=True, these characters are used to identify
+            parts to be aggregated (e.g. # in WordPiece, ['phen', '##omen', '##al']). Otherwise, they identify the
+            roots that should be preserved (e.g. ▁ in SentencePiece, ['▁phen', 'omen', 'al']).
         is_suffix_symbol (bool, optional): Whether the special symbol is used to identify suffixes or prefixes.
             Defaults to False.
     """
@@ -707,33 +707,33 @@ def aggregate(
         attr: "FeatureAttributionSequenceOutput",
         aggregate_source: bool = True,
         aggregate_target: bool = True,
-        special_symbol: str = "▁",
+        special_chars: Union[str, Tuple[str, ...]] = "▁",
         is_suffix_symbol: bool = False,
         **kwargs,
     ):
         source_spans = []
         target_spans = []
         if aggregate_source:
-            source_spans = cls.get_spans(attr.source, special_symbol, is_suffix_symbol)
+            source_spans = cls.get_spans(attr.source, special_chars, is_suffix_symbol)
         if aggregate_target:
-            target_spans = cls.get_spans(attr.target, special_symbol, is_suffix_symbol)
+            target_spans = cls.get_spans(attr.target, special_chars, is_suffix_symbol)
         return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)
 
     @staticmethod
-    def get_spans(tokens: List[TokenWithId], special_symbol: str, is_suffix_symbol: bool):
+    def get_spans(tokens: List[TokenWithId], special_chars: Union[str, Tuple[str, ...]], is_suffix_symbol: bool):
         spans = []
         last_prefix_idx = 0
-        has_special_symbol = any(sym in token.token for token in tokens for sym in special_symbol)
-        if not has_special_symbol:
+        has_special_chars = any(sym in token.token for token in tokens for sym in special_chars)
+        if not has_special_chars:
             logger.warning(
-                f"ATTENTION: The {special_symbol} symbol is currently used for subword aggregation, but no instances "
-                "have been detected in the sequence. Change the special symbols using e.g. special_symbol=('Ġ', 'Ċ')"
+                f"The {special_chars} character is currently used for subword aggregation, but no instances "
+                "have been detected in the sequence. Change the special symbols using e.g. special_chars=('Ġ', 'Ċ')"
                 ", and set is_suffix_symbol=True if they are used as suffix word separators (e.g. Hello</w> world</w>)"
             )
             return spans
         for curr_idx, token in enumerate(tokens):
             # Suffix if token start with special suffix symbol, or if it doesn't have the special prefix symbol.
-            is_suffix = token.token.startswith(special_symbol) == is_suffix_symbol
+            is_suffix = token.token.startswith(special_chars) == is_suffix_symbol
             if is_suffix:
                 if curr_idx == len(tokens) - 1 and curr_idx - last_prefix_idx > 1:
                     spans.append((last_prefix_idx, curr_idx))
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index 5f16d0d5..65ea6b30 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -143,26 +143,37 @@ def aggregate_contiguous(
     t: torch.Tensor,
     spans: Sequence[Tuple[int, int]],
     aggregate_fn: Optional[Callable] = None,
-    aggregate_dim: int = 1,
+    aggregate_dim: int = 0,
 ):
+    """Given a tensor, aggregate contiguous spans of the tensor along a given dimension using the provided
+    aggregation function. If no aggregation function is provided, the mean is used.
+
+    Args:
+        t: Tensor to aggregate
+        spans: Sequence of (start, end) tuples indicating contiguous spans to aggregate
+        aggregate_fn: Aggregation function to use. If None, torch.mean is used.
+        aggregate_dim: Dimension to aggregate along. Default is 0.
+    """
     if not spans:
         return t
     if aggregate_fn is None:
         aggregate_fn = torch.mean
-    while t.ndim < 2:
-        t = t.unsqueeze(-1)
-    t = t.transpose(aggregate_dim, 1)
+    if aggregate_dim >= t.ndim:
+        raise ValueError(f"aggregate_dim {aggregate_dim} is out of range for a tensor with {t.ndim} dimensions")
+    if aggregate_dim != 0:
+        t = t.transpose(aggregate_dim, 0)
     slices = []
     base_val = 0
     for start, end in spans:
         if start > base_val:
-            slices.append(t[:, base_val:start, ...])
-        slices.append(aggregate_fn(t[:, start:end, ...], dim=1).unsqueeze(1))
+            slices.append(t[base_val:start, ...])
+        slices.append(aggregate_fn(t[start:end, ...], dim=0).unsqueeze(0))
         base_val = end
-    slices.append(t[:, base_val:])
-    out_cat = torch.cat(slices, dim=1).transpose(1, aggregate_dim)
-    if 1 in out_cat.shape:
-        out_cat = out_cat.transpose(1, 0).squeeze(0)
+    if base_val < t.shape[0]:
+        slices.append(t[base_val:, ...])
+    out_cat = torch.cat(slices, dim=0)
+    if aggregate_dim != 0:
+        out_cat = out_cat.transpose(aggregate_dim, 0)
     return out_cat
 
 
@@ -174,8 +185,8 @@ def get_front_padding(t: torch.Tensor, pad: int = 0, dim: int = 1) -> List[int]:
 
 
 def get_sequences_from_batched_steps(bsteps: List[torch.Tensor]) -> List[torch.Tensor]:
-    """Given a sequence of batched step tensors of shape (batch_size, ...) builds a sequence
-    of tensors of shape (len(sequence), ...) where each resulting tensor is the aggregation
+    """Given a sequence of batched step tensors of shape (batch_size, seq_len, ...) builds a sequence
+    of tensors of shape (seq_len, ...) where each resulting tensor is the aggregation
     across batch steps for every batch element.
 
     Input tensors will be padded with nans up to max length in non-uniform dimensions to allow for stacking.
diff --git a/tests/data/test_aggregator.py b/tests/data/test_aggregator.py
index 92cbf80f..eb5086ca 100644
--- a/tests/data/test_aggregator.py
+++ b/tests/data/test_aggregator.py
@@ -63,7 +63,7 @@ def test_continuous_span_aggregator(saliency_mt_model: HuggingfaceEncoderDecoder
 
 def test_span_aggregator_with_prefix(saliency_gpt_model: HuggingfaceDecoderOnlyModel):
     out = saliency_gpt_model.attribute("Hello, world! I am,:.", "Hello, world! I am,:.!,. Last")
-    aggregated = out.aggregate("subwords", special_symbol=("Ġ", "Ċ")).aggregate()
+    aggregated = out.aggregate("subwords", special_chars=("Ġ", "Ċ")).aggregate()
     assert aggregated[0].target_attributions.shape == (5, 2)
     assert aggregated[0].attr_pos_start == 3
     assert aggregated[0].attr_pos_end == 5
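Note for reviewers: the `special_chars` rename leans on the fact that `str.startswith` natively accepts a tuple of prefixes, so `get_spans` needs no extra looping to support multiple boundary characters. A minimal sketch of the boundary check with made-up GPT-2-style tokens (the token list is illustrative, not taken from the test suite):

```python
tokens = ["ĠHello", ",", "Ġwor", "ld", "!"]  # 'Ġ' marks the start of a new word
special_chars = ("Ġ", "Ċ")
is_suffix_symbol = False  # the special chars mark word starts, not suffixes

# Mirrors the check in SubwordAggregator.get_spans: a token is a continuation
# ("suffix") when its startswith result matches is_suffix_symbol.
is_suffix = [tok.startswith(special_chars) == is_suffix_symbol for tok in tokens]
print(is_suffix)  # [False, True, False, True, True]
```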
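And a quick sketch of the reworked `aggregate_contiguous` contract, assuming the function is importable from the patched module; the `abs_max` helper and the sample tensors are illustrative only (`abs_max` reproduces the signed absolute-max aggregation described in the `SubwordAggregator` docstring, it is not part of this patch):

```python
import torch

from inseq.utils.torch_utils import aggregate_contiguous


def abs_max(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep the entry with the largest magnitude, preserving its sign:
    # [0.3, -0.7, 0.1] -> -0.7.
    idx = t.abs().argmax(dim=dim, keepdim=True)
    return t.gather(dim, idx).squeeze(dim)


t = torch.tensor([0.3, -0.7, 0.1, 4.0, 5.0])

# 1-D inputs now work directly, without the old unsqueeze/transpose dance:
print(aggregate_contiguous(t, spans=[(0, 3)]))                        # tensor([-0.1000, 4.0000, 5.0000])
print(aggregate_contiguous(t, spans=[(0, 3)], aggregate_fn=abs_max))  # tensor([-0.7000, 4.0000, 5.0000])

# Other dimensions are handled by transposing the target dim to 0 and back:
m = torch.arange(8.0).reshape(2, 4)
print(aggregate_contiguous(m, spans=[(1, 3)], aggregate_dim=1))
# tensor([[0.0000, 1.5000, 3.0000],
#         [4.0000, 5.5000, 7.0000]])
```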