Skip to content

Commit

Permalink
[HybridCache] Fix get_seq_length method (#31661)
Browse files Browse the repository at this point in the history
* fix gemma2

* handle in generate
  • Loading branch information
sanchit-gandhi authored Jun 27, 2024
1 parent 464aa74 commit 1c68f2c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,7 +1083,7 @@ def get_max_length(self) -> Optional[int]:
# no matter how long the sentence is
return self.max_cache_len

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None

def reset(self):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
cache = model_kwargs["past_key_values"]
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length"):
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()

if "inputs_embeds" in model_kwargs:
Expand Down

0 comments on commit 1c68f2c

Please sign in to comment.