Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window #28037
@@ -385,11 +385,16 @@ def forward(

if past_key_value is not None:
Context: when `use_cache` is `True`, `past_key_value` is now a `Cache` object even if it is an empty cache (previously it was `None`).
As such, with a sliding window, we need to check whether the cache has contents before attempting to slice, as we can't slice `None`.
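To illustrate the check described above, here is a minimal sketch. `ToyCache` and `trim_to_window` are hypothetical stand-ins written for this comment (plain Python lists instead of tensors), not the actual `transformers` classes or the exact code in this PR.

```python
from typing import List


class ToyCache:
    """Hypothetical stand-in for a per-layer KV cache; not the transformers class."""

    def __init__(self) -> None:
        self.keys: List[List[int]] = []  # one list of cached entries per layer

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # An empty cache has no entry for the layer yet, so report length 0.
        if len(self.keys) <= layer_idx:
            return 0
        return len(self.keys[layer_idx])

    def update(self, new_keys: List[int], layer_idx: int) -> List[int]:
        # Create missing layers on first use, then append in place.
        while len(self.keys) <= layer_idx:
            self.keys.append([])
        self.keys[layer_idx].extend(new_keys)
        return self.keys[layer_idx]


def trim_to_window(cache: ToyCache, layer_idx: int, sliding_window: int) -> None:
    """Keep only the last `sliding_window - 1` cached entries, if there are any."""
    cache_has_contents = cache.get_seq_length(layer_idx) > 0
    if not cache_has_contents:
        return  # nothing to slice on the first forward pass
    if cache.get_seq_length(layer_idx) >= sliding_window:
        slicing_tokens = 1 - sliding_window  # negative start index
        cache.keys[layer_idx] = cache.keys[layer_idx][slicing_tokens:]


cache = ToyCache()
trim_to_window(cache, layer_idx=0, sliding_window=4)  # empty cache: no-op, no error
cache.update([10, 11, 12, 13, 14], layer_idx=0)
trim_to_window(cache, layer_idx=0, sliding_window=4)
print(cache.keys[0])  # the last three entries remain
```

The point is only the ordering: ask the cache whether it holds anything for this layer first, and slice afterwards.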
@@ -400,8 +405,6 @@ def forward(
f" {past_key.shape}"
)

past_key_value = (past_key, past_value)
`past_key_value` is now a `Cache` instance that is updated in place with the `.update()` function (L413 in the updated file). We don't need to set it.
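A rough sketch of why the reassignment is unnecessary, assuming a cache object that mutates its own storage. `ToyCache` and `attention_step` are hypothetical names used only for illustration; the real `Cache` API may differ in its details.

```python
from typing import List, Tuple


class ToyCache:
    """Hypothetical per-layer cache used only for illustration."""

    def __init__(self) -> None:
        self.key_cache: List[List[int]] = []
        self.value_cache: List[List[int]] = []

    def update(
        self, keys: List[int], values: List[int], layer_idx: int
    ) -> Tuple[List[int], List[int]]:
        # Grow the per-layer storage on first use, then append in place.
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append([])
            self.value_cache.append([])
        self.key_cache[layer_idx].extend(keys)
        self.value_cache[layer_idx].extend(values)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


def attention_step(
    past_key_value: ToyCache, keys: List[int], values: List[int], layer_idx: int
) -> int:
    # The cache object itself is mutated; nothing is reassigned for the caller.
    all_keys, _all_values = past_key_value.update(keys, values, layer_idx)
    return len(all_keys)


cache = ToyCache()
attention_step(cache, [1, 2], [3, 4], layer_idx=0)
seen = attention_step(cache, [5], [6], layer_idx=0)
print(seen, cache.key_cache[0])
```

Because the caller and the layer hold the same object, the second call already sees the entries stored by the first.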
slicing_tokens = 1 - self.config.sliding_window

past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key_value[self.layer_idx][0]
`self.layer_idx` is guaranteed to be defined here, right?
Nope, good catch! Going to add an appropriate exception.
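For reference, one possible shape of such a guard, as a sketch only. `ToyAttention`, `cached_key`, and the error message are hypothetical and are not taken from the PR; the cache is modeled here as a plain list of `(key, value)` pairs per layer.

```python
from typing import List, Optional, Tuple


class ToyAttention:
    """Hypothetical attention layer that may be built without a layer index."""

    def __init__(self, layer_idx: Optional[int] = None) -> None:
        self.layer_idx = layer_idx

    def cached_key(self, past_key_value: List[Tuple[str, str]]) -> str:
        # Fail early with a clear message instead of indexing with None.
        if self.layer_idx is None:
            raise ValueError(
                "layer_idx must be set on the attention layer to index the cache"
            )
        return past_key_value[self.layer_idx][0]


layers = [("k0", "v0"), ("k1", "v1")]
print(ToyAttention(layer_idx=1).cached_key(layers))
```

Without the guard, indexing a list with `None` would raise a less helpful `TypeError`.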
This indeed seems to address what was described here, well done!
Thanks for fixing!
Looking forward to the FA2 tests being run 😅
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
As @younesbelkada mentioned, and as stated on the official site, there is nothing we can do unless we run on a different machine.
What does this PR do?

The FA2 code path was indexing the `Cache` object incorrectly. This PR fixes it.

Fixes #27985
NOTE: `tests/models/mistral/test_modeling_mistral.py::MistralIntegrationTest::test_model_7b_long_prompt` (slow test) was failing on `main`, but it was not popping up in our daily slow CI 🤔 Because of that, this issue flew under the radar. It is passing now.

Edit: the test was not run because we are skipping FA2 tests (`@require_flash_attn`). @ydshieh is on it :)