Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache: don't throw warnings on gemma2 when instantiating a new cache #33595

Merged
merged 3 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,15 @@ def get_max_length(self) -> Optional[int]:
return self.max_cache_len

def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
if layer_idx != 0:
raise ValueError(
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def reset(self):
"""Resets the cache values while preserving the objects"""
Expand Down
41 changes: 14 additions & 27 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,20 +710,13 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
config.n_positions - 1]`.

[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
past_key_values (`HybridCache`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.

The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
cache classes.

If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
Expand Down Expand Up @@ -789,7 +782,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -818,19 +811,8 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if cache_position is None:
if past_key_values is None:
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
else:
raise ValueError("When `past_key_values` is passed, `cache_position` must be too")

# Probably a forward call with caching, so we set up cache for one call only
if use_cache and past_key_values is None and not self.training:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two changes here, both to be consistent with other models:

  1. self.training should not control whether we instantiate a cache
  2. If a user respects the types in the docs, past_key_values is either a Cache or we instantiate a new one for the user without warnings

logger.warning_once(
"You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ",
"If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
"will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
)
# Instantiate an empty cache if needed.
if use_cache and past_key_values is None:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
Expand All @@ -840,6 +822,11 @@ def forward(
dtype=inputs_embeds.dtype,
)

if cache_position is None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp yes, if get_seq_length gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on layer_idx != 0 (I doubt we need it).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey sounds good, as long as the function of get_seq_length is transparent for users, to reduce number of cache-related question we get 😄

past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)

Expand Down Expand Up @@ -912,7 +899,7 @@ def _update_causal_mask(
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
past_key_values: HybridCache,
output_attentions: bool,
):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
Expand Down Expand Up @@ -981,7 +968,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
Expand Down Expand Up @@ -1202,7 +1189,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
Expand Down
12 changes: 10 additions & 2 deletions src/transformers/models/mimi/modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,8 +1000,16 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if use_cache and not isinstance(past_key_values, Cache):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Expand Down
5 changes: 5 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,15 @@ def setUp(self):
def test_model_outputs_equivalence(self, **kwargs):
pass

@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this parameterized, the intended overwriting was not happening

@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_inference(self):
pass

@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
def test_eager_matches_sdpa_generate(self):
pass

@parameterized.expand([("random",), ("same",)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
Expand Down
Loading