Skip to content

Commit

Permalink
Fix NaNs in attention scores, add pad token if missing (#269)
Browse files: browse the repository at this point in the history
  • Loading branch information
gsarti authored Apr 29, 2024
1 parent f2b9d92 commit 8de5e29
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
6 changes: 0 additions & 6 deletions inseq/attr/feat/internals_attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import logging
from typing import Any, Optional

import torch
from captum._utils.typing import TensorOrTupleOfTensorsGeneric

from ...data import MultiDimensionalFeatureAttributionStepOutput
Expand Down Expand Up @@ -77,11 +76,6 @@ def attribute(
# We adopt the format [batch_size, sequence_length, sequence_length, num_layers, num_heads]
# for consistency with other multi-unit methods (e.g. gradient attribution)
decoder_self_attentions = decoder_self_attentions.to("cpu").clone().permute(0, 4, 3, 1, 2)
decoder_self_attentions = torch.where(
decoder_self_attentions == 0,
(torch.ones_like(decoder_self_attentions) * float("nan")),
decoder_self_attentions,
)
if self.forward_func.is_encoder_decoder:
sequence_scores = {}
if len(inputs) > 1:
Expand Down
15 changes: 12 additions & 3 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,23 @@ def __init__(
self.tokenizer = tokenizer
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
if self.model.config.pad_token_id is not None:
self.pad_token = self._convert_ids_to_tokens(self.model.config.pad_token_id, skip_special_tokens=False)
self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
pad_token_id = self.model.config.pad_token_id
if pad_token_id is None:
if self.tokenizer.pad_token_id is None:
logger.info(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.")
pad_token_id = self.eos_token_id
else:
pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self._convert_ids_to_tokens(pad_token_id, skip_special_tokens=False)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.pad_token
if self.model.config.pad_token_id is None:
self.model.config.pad_token_id = pad_token_id
self.bos_token_id = getattr(self.model.config, "decoder_start_token_id", None)
if self.bos_token_id is None:
self.bos_token_id = self.model.config.bos_token_id
self.bos_token = self._convert_ids_to_tokens(self.bos_token_id, skip_special_tokens=False)
self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
if self.eos_token_id is None:
self.eos_token_id = self.tokenizer.pad_token_id
if self.tokenizer.unk_token_id is None:
Expand Down

0 comments on commit 8de5e29

Please sign in to comment.