Static Cache: no mandatory cache_positions input
#29221
@@ -247,7 +247,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
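The `cache_position` entry in `cache_kwargs` is what lets a static cache write the new key/value states into the correct slots of its preallocated buffers. A minimal sketch of that idea follows; it is not the actual `StaticCache.update` implementation, and the buffer shapes and the `index_copy_` strategy are assumptions for illustration.

```python
import torch

# Preallocated buffers with a fixed maximum length, as a static cache would hold them
# (illustrative shapes: batch=1, num_heads=4, max_cache_len=16, head_dim=8).
key_cache = torch.zeros(1, 4, 16, 8)
value_cache = torch.zeros(1, 4, 16, 8)

def update(key_states, value_states, cache_position):
    # Write the new states at the absolute positions given by `cache_position`,
    # leaving the rest of the preallocated buffers untouched.
    key_cache.index_copy_(2, cache_position, key_states)
    value_cache.index_copy_(2, cache_position, value_states)
    return key_cache, value_cache

# Prefill: a 5-token prompt occupies positions 0..4.
update(torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8), torch.arange(5))

# Decode step: the single new token lands at position 5.
update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), torch.tensor([5]))
```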
@@ -334,7 +334,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -533,7 +533,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
@@ -782,6 +782,10 @@ def _reset_cache(self):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
         """
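The padding remark in the new docstring is the important part: with left padding, `position_ids` restart the count after the pad tokens, while `cache_position` simply indexes the physical slots of the input. A small illustration with hand-picked values (not library code):

```python
import torch

# One left-padded sequence: two pad tokens followed by three real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])

# position_ids are usually derived from the attention mask, so the real tokens
# get positions 0, 1, 2 regardless of the padding in front of them.
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
print(position_ids)    # tensor([[0, 0, 0, 1, 2]])

# cache_position covers every slot of the input, padding included, so on the
# first forward pass it is simply 0..seq_len-1.
cache_position = torch.arange(attention_mask.shape[1])
print(cache_position)  # tensor([0, 1, 2, 3, 4])
```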
@@ -859,14 +863,19 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         past_seen_tokens = 0
-        if use_cache:  # kept for BC (cache positions)
-            if not isinstance(past_key_values, StaticCache):
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-                past_seen_tokens = past_key_values.get_seq_length()
+        if use_cache:
+            static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
+            if static_cache is not None:
+                past_seen_tokens = static_cache.get_seq_length()
+            else:
+                if not isinstance(past_key_values, Cache):
+                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()

Review comment (on the `getattr` line): we broke AWQ a few times with this, let's check.

Review comments on lines +866 to +873:
- that's a lot of work 😓
- IMO we should go towards: everybody should pass the cache positions, and we should not use …
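For readers skimming the diff, the new branch boils down to: if a static cache is attached to the attention layers, ask it directly how many positions are filled; otherwise wrap any legacy cache in a `Cache` object and ask that. A condensed, hypothetical helper restating the same logic (illustration only, not library API):

```python
def infer_past_seen_tokens(static_cache, past_key_values):
    """How many positions are already filled before this forward pass."""
    if static_cache is not None:
        # A static cache hangs off the attention layers, so `past_key_values`
        # is not passed around; query the cache itself.
        return static_cache.get_seq_length()
    if past_key_values is not None:
        # Dynamic (or legacy, once wrapped in a Cache object) caches answer the same query.
        return past_key_values.get_seq_length()
    return 0
```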
         if cache_position is None:
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+            # `torch.compile`-friendly `torch.arange` from a shape
+            cache_position = (
+                torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
+            )

         if position_ids is None:

Review comment (on the new `cache_position` computation): does that also fix the ONNX export we had?
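The replacement derives `cache_position` from the shape of `inputs_embeds` instead of calling `torch.arange` with a data-dependent start, which tends to play better with `torch.compile`. The two formulations produce the same values, as this quick check suggests (a sketch; compile-time behaviour is not benchmarked here):

```python
import torch

past_seen_tokens = 7
inputs_embeds = torch.randn(2, 3, 16)  # (batch, new_tokens, hidden_size)

# Straightforward version with a data-dependent start value.
a = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])

# Shape-derived version from the diff: cumsum over a ones tensor, then shifted.
b = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1

assert torch.equal(a, b)  # both are tensor([7, 8, 9])
```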
@@ -1101,14 +1110,24 @@ def forward(
             )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
     ):
+        # With static cache, the `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove of this if
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
+                past_length = (
+                    cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length()
+                )
                 max_cache_length = past_key_values.get_max_length()
+                cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
Review comment on lines +1125 to +1129: this restructure prioritizes `cache_position` over `past_key_values.seen_tokens` when computing `past_length`. Note that …
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
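Restated outside the model, the reworked branch gives `cache_position` priority when deriving `past_length`, and only then caps `cache_length` by the cache's maximum length when one exists. A hypothetical helper mirroring the diff (names are illustrative, not library API):

```python
def derive_lengths(cache, cache_position):
    # Prefer the explicit cache_position over the cache's own counters.
    past_length = int(cache_position[-1]) + 1 if cache_position is not None else cache.get_seq_length()
    # get_max_length() is None for caches that can grow without bound.
    max_cache_length = cache.get_max_length()
    cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
    return cache_length, past_length, max_cache_length
```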
@@ -1141,19 +1160,11 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
-            # generation with static cache
-            cache_position = kwargs.get("cache_position", None)
-            if cache_position is None:
-                past_length = 0
-            else:
-                past_length = cache_position[-1] + 1
-            input_ids = input_ids[:, past_length:]
-            position_ids = position_ids[:, past_length:]
-
         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        position_ids = position_ids.contiguous() if position_ids is not None else None
Review comment on lines +1165 to +1167: this is already on main (see here), not sure why this shows up 👀
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
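With the static-cache special case gone, `cache_position` is rebuilt the same way on every call: a range starting at `past_length` that covers however many tokens are fed in. Over a generate loop that looks roughly like this (hand-picked numbers, not library code):

```python
import torch

# Prefill: the whole 6-token prompt goes through, so positions 0..5 are written.
past_length, input_length = 0, 6
print(torch.arange(past_length, past_length + input_length))  # tensor([0, 1, 2, 3, 4, 5])

# First decode step: only the newly sampled token is fed, at position 6.
past_length, input_length = 6, 1
print(torch.arange(past_length, past_length + input_length))  # tensor([6])
```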
@@ -1164,6 +1175,9 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}

+        if has_static_cache:
+            past_key_values = None
+
         model_inputs.update(
             {
                 "position_ids": position_ids.contiguous(),
Review comment: nice 😉
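Taken together, the changes mean `cache_position` no longer has to be threaded through by hand when generating with a static cache; it is inferred when omitted. A rough usage sketch, assuming a transformers version where `cache_implementation="static"` is supported on the generation config and using a placeholder checkpoint name:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; any Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Opt into the static cache; cache_position is derived internally,
# so no extra kwargs are needed at call time.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Static caches make torch.compile happy because", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```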