Skip to content

Commit

Permalink
fix gemma2
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jun 27, 2024
1 parent 75a6319 commit b94f46d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,11 @@ def get_max_length(self) -> Optional[int]:
return self.max_cache_len

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return None
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def reset(self):
"""Resets the cache values while preserving the objects"""
Expand Down

0 comments on commit b94f46d

Please sign in to comment.