Static Cache: no mandatory cache_positions input #29221
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @fxmarty :)
I don't think that is the way we should go. For now I have not seen issues with having a new argument; we added padding_mask at some point.
IMO we should move to having cache_positions in generate, as this is more explicit, easier to maintain, and less error prone. More than that, these operations are input dependent, and even vectorized they add complexity where there should not be any!
No specific opinion, just that using
transformers/src/transformers/models/llama/modeling_llama.py, lines 975 to 983 in 75ed76e
The only thing I think:
EDIT -- before you read, why I believe removing the mandatory input is the way to go: ✅ much simpler computations (@ArthurZucker this addresses your concerns, although the previous version also did not introduce slowdowns)

@fxmarty I agree here that it should be renamed :) e.g. [...]

@ArthurZucker I strongly disagree, more redundant inputs result in more bugs 🐛 From the full input_ids and pad_token_id we can derive attention_mask, position_ids, and cache_positions. This is arguably much more work than enforcing the correct pairing in the models themselves! On top of that, other functions like [...]

My goal with this PR is to iron out future sources of bugs before we roll out these changes to other models. Better than well-explained interfaces is... no need for that interface at all! Without these changes, users have to learn how to prepare a new input tensor to use static caches.

In terms of complexity, it will always exist. The difference is whether it exists outside [...]

In terms of performance, we can see that it is negligible in eager forwards and none in compiled static forwards. Again, if not implemented here, it will be implemented in [...]
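For concreteness, a minimal sketch of the derivation described above for the prefill step (the tensor values and the empty-cache assumption are illustrative, not taken from the PR):

```python
import torch

# Given the full, left-padded input_ids and the pad_token_id, rebuild the companion tensors.
input_ids = torch.tensor([[0, 0, 11, 12, 13],
                          [21, 22, 23, 24, 25]])
pad_token_id = 0

attention_mask = (input_ids != pad_token_id).long()          # 0 on padding, 1 on real tokens
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)  # positions start after the pads
past_seen_tokens = 0                                         # assuming an empty cache (prefill)
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + input_ids.shape[1])
```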
src/transformers/cache_utils.py (Outdated)

# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.key_cache[0, 0].any(dim=-1)).sum()
I'm not a fan of this. This would be run at each forward. Couldn't we just increment somewhere, or not use get_seq_length at all in the forward?
Also, given the issues we've had: did you check that it works with compile? Meaning, is self.model.layers[0].past_key_values.get_seq_length() correct in between each forward call in generate when the model is compiled with model = torch.compile(model, mode="reduce-overhead")? I removed the requirement on get_seq_length + seen_tokens as we were getting wrong values.
@fxmarty yes, this version works! E.g. try running RUN_SLOW=1 py.test tests/test_cache_utils.py::CacheIntegrationTest::test_static_cache_greedy_decoding_pad_left_0_eager on the latest commit, which confirms that dynamic, eager static cache, and compiled static cache all produce the same sensible outputs.

> I'm not a fan of this. This would be run at each forward. Couldn't we just increment somewhere?

I am very much aligned with your comment! Indeed, that was my original plan when adding seen_tokens to the original cache implementation. However, as you mentioned, we haven't found a way to make it work at compile time. Perhaps we should add a TODO to revisit this in the future, after the issue you opened on the PyTorch side has been addressed?
- if part of the previous cache happens to be a zero tensor, this check breaks (very rare in full precision, maybe not in half precision; we had a similar issue with llava when relying on tensors being zeros) -- see the toy check below
- the cost is OK, I guess
- it's totally implicit 😢, and that's less aligned with our overall philosophy.
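A toy version of the occupancy check makes the first point concrete (names and shapes are illustrative, not the actual StaticCache internals):

```python
import torch

max_cache_len, head_dim = 8, 4
key_slice = torch.zeros(max_cache_len, head_dim)   # stands in for self.key_cache[0, 0]

key_slice[0] = torch.randn(head_dim)
key_slice[1] = torch.randn(head_dim)
print(key_slice.any(dim=-1).sum())    # tensor(2): two occupied slots detected

key_slice[2] = torch.zeros(head_dim)  # a legitimately written but all-zero key vector...
print(key_slice.any(dim=-1).sum())    # ...still tensor(2): that occupied slot is missed
```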
Thanks for your great explanation.
TL;DR: I think part of this choice comes down to implicit vs. explicit. Your approach works, but it's implicit. Both have ➕ and ➖ for maintainability. I am leaning a lot more towards the explicit one, for a few reasons:
- If you generate the cache_positions in the model with torch.arange, the generations you get with CUDA graphs will be wrong, precisely because the cache positions are not given as input. I tried this and it does not work.
- ➕ Instead of relying on a hidden seen_tokens and three different methods (get_length, get_usable_length, get_max_length), you just need cache_positions in generate.
- ➕ Stateless is less error prone; cache_positions is also pretty well known from GPT-Fast.
- "From the full input_ids and pad_token_id we can derive attention_mask, position_ids, and cache_positions" is not true: if you want to do paged attention / packed training with a custom 4D attention mask, or anything a bit outside the classic forward (any custom Cache class), then you cannot rely only on input_ids and the pad token.
- The generate function should keep track of where it is, for simplicity but also because otherwise you're not really able to manipulate anything outside the forward of the model. So at the end of the day you have a loop in generate, which is the only public place where we use the past_key_values, but you don't let it handle the cache_positions, which is a bit strange!

Let's keep the cache positions, gradually remove calls to get_seq_length, and explicitly pass the arguments to the model instead of relying on the model.
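A rough sketch of the explicit flow argued for above, where the decoding loop owns cache_position and hands it to the model (model here is a placeholder callable returning logits, not the actual transformers API):

```python
import torch

def greedy_decode(model, input_ids, max_new_tokens):
    # The loop, not the model, tracks the cache positions (GPT-Fast style).
    cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
    generated = input_ids
    for _ in range(max_new_tokens):
        # Feed only the tokens covered by the current cache positions.
        logits = model(generated[:, -cache_position.shape[0]:], cache_position=cache_position)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        # Advance explicitly: the next step writes exactly one new cache slot.
        cache_position = cache_position[-1:] + 1
    return generated
```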
# `torch.compile`-friendly `torch.arange` from a shape
cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
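As I read it, the point of the trick is to derive positions from the tensor's shape rather than from a data-dependent torch.arange call, which plays nicer with torch.compile. A quick equivalence check with illustrative shapes:

```python
import torch

past_seen_tokens = 5
inputs_embeds = torch.randn(2, 3, 16)  # (batch, seq_len, hidden), illustrative values

via_cumsum = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
via_arange = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
assert torch.equal(via_cumsum, via_arange)  # both are tensor([5, 6, 7])
```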
If the cache positions are used internally anyway, then that means we still have to create them at some point.
@ArthurZucker as we talked about on Slack, the latest set of commits makes [...]
NOTE: ignore the changes in [...]
In summary, this PR: [...]
past_length = (
    cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length()
)
max_cache_length = past_key_values.get_max_length()
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
This restructure prioritizes cache_position, falling back to .get_seq_length() in its absence. This replaces seen_tokens, which is now deprecated.
Note that past_length [all seen tokens] and cache_length [tokens held in the cache] are both needed, otherwise SinkCache won't work.
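With made-up numbers, the distinction looks like this (assuming a bounded cache whose get_max_length() reports a finite window):

```python
# Hypothetical numbers, only to show why both quantities are needed:
past_length = 1000        # all tokens seen so far, i.e. cache_position[-1] + 1
max_cache_length = 256    # what get_max_length() would report for a bounded cache
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
print(past_length, cache_length)  # 1000 256 -> positions keep growing, the cache does not
```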
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None
This is already on main (see here), not sure why this shows up 👀
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@gante could you rebase / fix merges and I'll review!
Late review, not sure what the status is.
Overall the goal is to not have to pass cache positions, but the logic is very complicated, while asking users to pass the cache positions seems a lot simpler, no?
@property
def seen_tokens(self):
    logger.warning_once(
        "The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` "
        "variable instead."
    )
    if hasattr(self, "_seen_tokens"):
        return self._seen_tokens
    else:
        return None
nice 😉
if use_cache:
    static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
    if static_cache is not None:
        past_seen_tokens = static_cache.get_seq_length()
    else:
        if not isinstance(past_key_values, Cache):
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()
That's a lot of work 😓 It does not seem like this is needed? Two cases (roughly as sketched below):
- no cache positions -> not using generate, or not using cache positions -> use the DynamicCache, so the previous code works for the past length
- cache positions -> use them
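A minimal sketch of that two-case handling, assuming the Cache / DynamicCache utilities (not the PR's actual code):

```python
from transformers.cache_utils import Cache, DynamicCache

def resolve_past_seen_tokens(past_key_values, cache_position):
    # Case 1: explicit positions were passed (e.g. by generate) -> trust them.
    if cache_position is not None:
        return int(cache_position[0])
    # Case 2: no cache positions -> fall back to the cache's own bookkeeping.
    if not isinstance(past_key_values, Cache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values.get_seq_length()
```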
IMO we should move towards everybody passing the cache positions, and we should not use past_seen_tokens = static_cache.get_seq_length().
if cache_position is None:
    cache_position = torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )
# `torch.compile`-friendly `torch.arange` from a shape
Does that also fix the ONNX export issue we had?
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache:
    static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
We broke AWQ a few times with this; let's check generation_config.cache_implementation instead?
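Something along these lines, perhaps (attribute paths are illustrative; generation_config.cache_implementation is the field suggested above):

```python
def past_seen_tokens_from_config(model, past_key_values):
    # Decide based on the generation config instead of probing layer attributes,
    # which quantized models (e.g. AWQ) may have replaced.
    if getattr(model.generation_config, "cache_implementation", None) == "static":
        return model.model.layers[0].self_attn.past_key_value.get_seq_length()
    return past_key_values.get_seq_length()
```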
Closing this PR and other cache PRs, as we want to move in the opposite direction (static cache behaving like the other caches).
What does this PR do?
Removes the hard requirement of cache_position from the public model classes (e.g. classes that can be loaded with AutoModel). Its contents can be derived from the cache's get_seq_length().

Performance and Quality checks
Slow tests run, getting the same outcome as in main:
- RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv
- RUN_SLOW=1 py.test tests/models/gemma/test_modeling_gemma.py -vv
- RUN_SLOW=1 py.test tests/test_cache_utils.py -vv [note: two new tests were added here]

👉 Note that these tests ensure that torch.compile works, with and without attention_mask being passed.

Local benchmark (RTX3090, tiny llama) -- no changes between main and this PR [benchmark outputs omitted]