
Generate: handle cache_position update in generate #29467

Merged (6 commits) into huggingface:main on Mar 14, 2024

Conversation

@gante (Member) commented Mar 5, 2024

What does this PR do?

Updates cache_position in generate, and makes it the primary source for the input position in the models that support it (Llama and Gemma), as opposed to relying on past_key_values.seen_tokens.

The PR also adds the following related changes:

  1. StaticCache now supports get_seq_length(). This was drawn from Static Cache: no mandatory cache_positions input #29221, and is needed for backward compatibility of .prepare_inputs_for_generation();
  2. The seen_tokens attribute enters a deprecation cycle, as it is redundant with cache_position (and doesn't work with compilation).

This PR is drawn from the diff in #29374, i.e. it is a requirement for generate compilation with fullgraph=True 🙌

👉 Llama, Gemma, and Cache slow tests were run; no new failures
👉 Forward-pass compilation benchmarks were run; no throughput change
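
A minimal sketch of the cache_position bookkeeping described above, assuming the usual (batch, seq_len) input layout; the values and loop here are illustrative, not the exact transformers implementation:

```python
import torch

prompt_ids = torch.tensor([[1, 42, 7, 99]])  # (batch, seq_len) prompt

# Prefill: one position index per prompt token.
cache_position = torch.arange(prompt_ids.shape[1])  # tensor([0, 1, 2, 3])

# Each subsequent decoding step feeds a single new token, so the update keeps
# only the last position and advances it by one.
for _ in range(3):
    cache_position = cache_position[-1:] + 1  # tensor([4]), tensor([5]), tensor([6])
    print(cache_position)
```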


@gante marked this pull request as ready for review March 6, 2024 16:19
@gante requested a review from ArthurZucker March 6, 2024 16:19
@ArthurZucker (Collaborator) left a comment:

Alright, I think Llama is already testing this. Moving fast here

Comment on lines +418 to +420
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
Collaborator:

alright, we are deprecating this anyways
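
For context on why the quoted heuristic is fragile: a cached key vector that happens to be exactly zero is indistinguishable from an unfilled slot. A small sketch of the failure mode, assuming a StaticCache-like key_cache layout of (batch, num_heads, max_seq_len, head_dim):

```python
import torch

key_cache = torch.zeros(1, 1, 8, 4)  # pre-allocated static cache with 8 slots
key_cache[0, 0, 0] = torch.ones(4)   # slot 0 genuinely filled
key_cache[0, 0, 1] = torch.zeros(4)  # slot 1 also filled, but with exact zeros

# The heuristic counts slots whose key vector has any nonzero entry, so the
# all-zero (yet filled) slot 1 is missed.
seen_tokens = key_cache[0, 0].any(dim=-1).sum()
print(seen_tokens)  # tensor(1), even though 2 positions were written
```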

@@ -663,7 +662,8 @@ def _update_model_kwargs_for_generation(
dim=-1,
)

model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
Collaborator:

my single worry here is potential stride, adding a .contiguous() might be needed

Member (Author):

I've double-checked, it's always (1,) 🤗 (which makes sense, since it's a 1D tensor)

Its shape will indeed be different, at least between prefill and subsequent generation
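
A quick sanity check of the stride point (a minimal sketch in plain PyTorch): slicing the last element of a 1-D position tensor and adding 1 allocates a fresh tensor, so the result is contiguous with stride (1,).

```python
import torch

cache_position = torch.arange(10)
next_position = cache_position[-1:] + 1

print(next_position)                  # tensor([10])
print(next_position.stride())         # (1,)
print(next_position.is_contiguous())  # True
```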

Collaborator:

We should also set the dtype of the cache positions to int32 wdyt?

Member (Author):

Our integer inputs (input_ids, attention_mask, ...) are all int64, I think we should keep a consistent type :p
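
For reference, a minimal sketch of the dtype point: torch.arange with integer bounds defaults to int64 (torch.long), matching the dtype of the other integer inputs.

```python
import torch

input_ids = torch.tensor([[1, 42, 7]])
cache_position = torch.arange(input_ids.shape[1])

print(input_ids.dtype)       # torch.int64
print(cache_position.dtype)  # torch.int64
```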

@@ -790,6 +790,10 @@ def _reset_cache(self):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Collaborator:

we have correct long typing here!

Member (Author):

(see the int64 comment above)

@gante (Member, Author) commented Mar 14, 2024:

(rebasing and rerunning tests, just in case 🙃 )

@gante merged commit 23db187 into huggingface:main on Mar 14, 2024
21 checks passed
@gante deleted the update_cache_position branch March 14, 2024 16:35
itsdotscience added a commit to itsdotscience/LLaVA that referenced this pull request on Mar 22, 2024, to resolve the error `TypeError: LlavaLlamaForCausalLM.forward() got an unexpected keyword argument 'cache_position'` introduced by huggingface/transformers#29467

itazap pushed a commit that referenced this pull request on May 14, 2024