Cache: don't throw warnings on gemma2 when instantiating a new cache #33595
Conversation
src/transformers/cache_utils.py
Outdated
    def get_seq_length(self, layer_idx: Optional[int] = 0):
-        return None
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value.
+        # To save on compute, let's limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
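The occupancy check above can be sketched in plain Python, using nested lists as a hypothetical stand-in for the torch tensor ops (a minimal sketch, assuming a statically pre-allocated cache where unused slots are all-zero):

```python
# Sketch of the static-cache occupancy check, with nested lists in place of
# torch tensors. Assumes a pre-allocated cache where empty slots are all-zero.
# Layout (hypothetical): key_cache[layer][batch][head][position][head_dim].

def get_seq_length(key_cache, layer_idx=0):
    """Count occupied positions, checking only the first batch member and head:
    a position counts as occupied iff any value in its vector is non-zero."""
    positions = key_cache[layer_idx][0][0]  # [seq_len][head_dim]
    return sum(1 for vec in positions if any(v != 0 for v in vec))

# A cache with capacity 4 where 2 positions have been written.
cache = [[[[[0.1, 0.2], [0.3, 0.0], [0.0, 0.0], [0.0, 0.0]]]]]
print(get_seq_length(cache))  # 2
```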
`HybridCache` is a `StaticCache` with alternating sliding-window layers. The method to retrieve the cache length is copy/pasted from `StaticCache`. We will want to use another method in the future, but let's leave this as a copy of `StaticCache` for now. This method is needed in the updated Gemma 2.
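The alternating layout is why per-layer lengths can diverge, as discussed below. A rough illustration (hypothetical helper, not the transformers API; which parity is global vs. sliding is an assumption here):

```python
# Hypothetical sketch: in a hybrid cache, some layers attend globally (full
# length) while the others use a sliding window, so per-layer cache lengths
# can diverge once the sequence exceeds the window.

def per_layer_length(total_tokens, num_layers, sliding_window):
    lengths = []
    for layer_idx in range(num_layers):
        if layer_idx % 2 == 0:  # assumed: even layers are global-attention
            lengths.append(total_tokens)
        else:                   # odd layers keep at most the window
            lengths.append(min(total_tokens, sliding_window))
    return lengths

print(per_layer_length(total_tokens=10, num_layers=4, sliding_window=6))
# [10, 6, 10, 6]
```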
        raise ValueError("When `past_key_values` is passed, `cache_position` must be too")

    # Probably a forward call with caching, so we set up cache for one call only
    if use_cache and past_key_values is None and not self.training:
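The replacement condition drops the `self.training` gate in favor of a type check. A minimal sketch with stub classes (the real `Cache` and `DynamicCache` live in `transformers.cache_utils`; these stand-ins are only for illustration):

```python
# Sketch of the new cache-instantiation logic, with stub classes standing in
# for transformers' Cache / DynamicCache.
# Old: `if use_cache and past_key_values is None and not self.training:`
# New: instantiate whenever the user didn't pass a Cache instance.

class Cache:  # stand-in for the Cache base class
    pass

class DynamicCache(Cache):
    pass

def prepare_cache(past_key_values, use_cache):
    if use_cache and not isinstance(past_key_values, Cache):
        # No legacy conversion, so no deprecation warning is triggered.
        past_key_values = DynamicCache()
    return past_key_values

assert isinstance(prepare_cache(None, use_cache=True), DynamicCache)
existing = DynamicCache()
assert prepare_cache(existing, use_cache=True) is existing  # left untouched
```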
Two changes here, both to be consistent with other models:
- `self.training` should not control whether we instantiate a cache
- If a user respects the types in the docs, `past_key_values` is either a `Cache` or we instantiate a new one for the user without warnings
@@ -840,6 +822,11 @@ def forward(
            dtype=inputs_embeds.dtype,
        )

+        if cache_position is None:
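The llama-style default derives the positions from how many tokens the cache already holds. A sketch with plain Python ranges in place of `torch.arange` (shapes and names assumed for illustration):

```python
# Sketch: when the caller passes no cache_position, derive it from the number
# of tokens already in the cache (llama-style), here with plain ranges instead
# of torch.arange.

def default_cache_position(past_seen_tokens, new_seq_len, cache_position=None):
    if cache_position is None:
        cache_position = list(range(past_seen_tokens,
                                    past_seen_tokens + new_seq_len))
    return cache_position

print(default_cache_position(past_seen_tokens=3, new_seq_len=2))  # [3, 4]
```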
copy/paste from llama (and other `Cache`-supporting models)
Okay, this should always work actually, since the sequence length is read with `layer_idx=0`. Just one question: isn't it a bit misleading if some layers will hold `get_seq_length()` tokens while others hold no more than the sliding-window length?
@zucchini-nlp yes, if `get_seq_length` gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on `layer_idx != 0` (I doubt we need it).
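The proposed guard could look roughly like this (hypothetical sketch, not the merged implementation; the per-layer length store is a stand-in):

```python
# Sketch of the proposed guard: get_seq_length is only meaningful on layer 0,
# where a hybrid cache keeps the full (non-windowed) sequence.

def get_seq_length_guarded(lengths_per_layer, layer_idx=0):
    if layer_idx != 0:
        raise ValueError(
            "`get_seq_length` on a hybrid cache is only defined for "
            "layer_idx=0; sliding-window layers do not keep the full sequence."
        )
    return lengths_per_layer[0]

print(get_seq_length_guarded([10, 6], layer_idx=0))  # 10
```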
Okay, sounds good, as long as the behavior of `get_seq_length` is transparent for users, to reduce the number of cache-related questions we get 😄
@@ -1000,8 +1000,16 @@ def forward(
            )
            use_cache = False

-        if use_cache and past_key_values is None and not self.training:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_cache and not isinstance(past_key_values, Cache):
copy/paste from llama (and other Cache-supporting models)
@@ -86,10 +86,15 @@ def setUp(self):
        def test_model_outputs_equivalence(self, **kwargs):
            pass

+        @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
Without this `parameterized` decorator, the intended overriding was not happening.
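Why the bare override didn't take effect: `parameterized.expand` generates suffixed test names at class-definition time, so a subclass method with the plain name doesn't shadow them. A minimal illustration without the library (the generated name scheme here is an assumption for the sketch):

```python
# Illustration: a decorator like parameterized.expand attaches methods named
# test_model_0_float16, test_model_1_bfloat16, ... to the class. A subclass
# that defines a plain `test_model` does NOT override those generated names.

class Base:
    pass

# Simulate what the expansion does at class-definition time.
for i, dtype in enumerate(["float16", "bfloat16", "float32"]):
    setattr(Base, f"test_model_{i}_{dtype}", lambda self, d=dtype: f"base-{d}")

class Child(Base):
    def test_model(self):  # does not shadow test_model_0_float16 etc.
        return "child"

c = Child()
print(c.test_model_0_float16())  # base-float16 -> base versions still run
```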
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you! Please merge once @zucchini-nlp has approved, as she knows this code better than I do.
cc @BenjaminBossan as well
LGTM, thanks for cleaning up warnings! Left one question about HybridCache, since I was reluctant to add seq-length for that cache type where lengths are not consistent over layers
I'm not qualified to review this, but thanks for addressing it so quickly.
What does this PR do?
Related to #33541
The warning in question should only be thrown in the case we are converting from a legacy cache, which will be deprecated soon. Gemma 2 doesn't support the legacy cache format, so no warning should ever be thrown :)
In the process, updates a few related inconsistencies.
✅ Slow `gemma2` tests ran locally. There are a few failures (also present on `main`); some failures were fixed in this PR.