Generate: handle cache_position update in generate #29467
Conversation
Force-pushed from f5c91b9 to 572ca8e
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Alright, I think Llama is already testing this. Moving fast here
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
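For context, a hedged standalone sketch of what that heuristic computes; the tensor shapes and values below are made up for illustration, not taken from the PR:

```python
import torch

# Illustration only: a static cache is pre-allocated with zeros, so "filled" positions
# are inferred from key vectors that contain at least one non-zero value.
max_cache_len, head_dim = 8, 4
key_cache = torch.zeros(1, 1, max_cache_len, head_dim)  # (batch, heads, seq, head_dim)
key_cache[0, 0, :3] = torch.randn(3, head_dim)          # pretend 3 tokens were written

seen_tokens = key_cache[0, 0].any(dim=-1).sum()
print(seen_tokens)  # tensor(3)

# The failure mode flagged in the TODO: a token whose key vector happens to be all
# zeros would not be counted, hence the plan to switch to a stateless integer.
```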
alright, we are deprecating this anyway
@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
            dim=-1,
        )

        model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
my single worry here is potential stride, adding a `.contiguous()` might be needed
I've double-checked, the stride is always (1,) 🤗 (which makes sense, since it's a 1D tensor)
Its shape will indeed be different, at least between prefill and subsequent generation
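A quick check backing that observation (illustrative snippet, not from the PR):

```python
import torch

# Both the prefill positions and the sliced/advanced decode positions are 1D and
# contiguous, so their stride is (1,) and no `.contiguous()` call should be needed.
prefill = torch.arange(5)
decode = prefill[-1:] + 1

print(prefill.stride(), decode.stride())  # (1,) (1,)
print(prefill.shape, decode.shape)        # torch.Size([5]) torch.Size([1])
```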
We should also set the dtype of the cache positions to int32
wdyt?
Our integer inputs (`input_ids`, `attention_mask`, ...) are all `int64`, I think we should keep a consistent type :p
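A minimal dtype check illustrating that point; the tensors below are stand-ins, not real model inputs:

```python
import torch

# Tokenizers emit int64 input_ids / attention_mask, and torch.arange also defaults
# to int64, so the cache positions stay consistent without an explicit cast.
input_ids = torch.tensor([[101, 2023, 2003, 102]])  # stand-in for tokenizer output
attention_mask = torch.ones_like(input_ids)
cache_position = torch.arange(input_ids.shape[1])

print(input_ids.dtype, attention_mask.dtype, cache_position.dtype)
# torch.int64 torch.int64 torch.int64
```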
@@ -790,6 +790,10 @@ def _reset_cache(self):
                more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
we have correct long typing here!
(see int64 comment above)
Force-pushed from 58660e2 to 10360b3
(rebasing and rerunning tests, just in case 🙃)
To resolve the error `TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'` introduced by huggingface/transformers#29467
What does this PR do?

Updates `cache_position` in `generate`, and makes it the primary source for the input position in the models that support it, `llama` and `gemma` (as opposed to relying on `past_key_values.seen_tokens`).

The PR also adds the following related changes:
- `StaticCache` now supports `get_seq_length()`. This was drawn from "Static Cache: no mandatory `cache_positions` input" (#29221), and is needed for `.prepare_inputs_for_generation()` retrocompatibility;
- the `seen_tokens` attribute enters a deprecation cycle, as it is redundant with `cache_positions` (and doesn't work with compilation).

This PR is drawn from the diff in #29374, i.e. it is a requirement for `generate` compilation with `fullgraph=True` 🙌

👉 Llama, Gemma, and Cache slow tests ran, no new failures
👉 FWD compilation benchmarks ran, no throughput change
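As a usage-level illustration of what this enables, here is a hedged sketch; the model name, dtype, and the `cache_implementation = "static"` setting are assumptions for illustration rather than part of this PR's diff:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup: a Llama checkpoint with the static cache enabled, so that generate
# tracks positions through cache_position instead of past_key_values.seen_tokens.
model_id = "meta-llama/Llama-2-7b-hf"  # illustrative model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

model.generation_config.cache_implementation = "static"  # opt into StaticCache

inputs = tokenizer("The capital of France is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```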