Static Cache: no mandatory cache_positions input #29221

Closed
wants to merge 12 commits
38 changes: 24 additions & 14 deletions src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` "
"variable instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
Comment on lines +64 to +73
Collaborator: nice 😉



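The deprecation above replaces a stateful token counter with the `cache_position` tensor that the rest of this PR threads through the models. A minimal sketch of the replacement bookkeeping, with made-up token counts (only the attribute and variable names come from the diff):

```python
import torch

# Prefill: the cache absorbs the whole 5-token prompt at once.
prompt_len = 5
cache_position = torch.arange(prompt_len)   # tensor([0, 1, 2, 3, 4])
seen = int(cache_position[-1]) + 1          # 5, the tally `past_key_values.seen_tokens` used to expose

# One decoding step later: a single new token lands in slot 5.
cache_position = cache_position[-1:] + 1    # tensor([5])
seen = int(cache_position[-1]) + 1          # 6
```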
class DynamicCache(Cache):
"""
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def __init__(self) -> None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"""
@@ -121,7 +136,7 @@ def update(
"""
# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
self.window_length = window_length
self.num_sink_tokens = num_sink_tokens
self.cos_sin_cache = {}
self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

@staticmethod
def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(

# Update the number of seen tokens
if layer_idx == 0:
self.seen_tokens += key_states.shape[-2]
self._seen_tokens += key_states.shape[-2]

# [bsz, num_heads, seq_len, head_dim]
if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)

def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise ValueError(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
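A small, self-contained illustration of the occupancy check behind the new `StaticCache.get_seq_length` above, including the `0.0` caveat from the TODO. The cache shape below is invented; only the `.any(dim=-1).sum()` idiom is taken from the diff:

```python
import torch

# Pretend per-layer static cache: [batch, num_heads, max_seq_len, head_dim], pre-allocated with zeros.
key_cache = torch.zeros(2, 4, 8, 16)

# Write 3 tokens' worth of keys into the first slots, as `update(..., cache_position=...)` would.
cache_position = torch.arange(3)
key_cache[:, :, cache_position] = torch.randn(2, 4, 3, 16)

# Occupied slots = positions where any feature of batch 0 / head 0 is non-zero.
print(key_cache[0, 0].any(dim=-1).sum())   # tensor(3)

# Caveat from the TODO: a genuinely all-zero key vector is miscounted as empty.
key_cache[0, 0, 1] = 0.0
print(key_cache[0, 0].any(dim=-1).sum())   # tensor(2), an undercount
```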
60 changes: 37 additions & 23 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -247,7 +247,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -334,7 +334,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -533,7 +533,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -782,6 +782,10 @@ def _reset_cache(self):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


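To make the "not affected by padding" remark in the new docstring concrete, a toy comparison (the padding layout and the `position_ids` recipe are illustrative, not lifted from the modeling code):

```python
import torch

# Left-padded batch of two prompts (prompt lengths 2 and 4).
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])

# `position_ids` is per-sequence and shifts with the padding...
position_ids = (attention_mask.cumsum(-1) - 1).clamp(min=0)
print(position_ids)        # tensor([[0, 0, 0, 1], [0, 1, 2, 3]])

# ...while `cache_position` is a single 1-D index over cache slots, shared by the whole batch.
cache_position = torch.arange(attention_mask.shape[1])
print(cache_position)      # tensor([0, 1, 2, 3])
```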
@@ -859,14 +863,19 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache:
static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
Collaborator: we broke AWQ a few times with this, let's check generation_config.cache_implementation?

if static_cache is not None:
past_seen_tokens = static_cache.get_seq_length()
else:
if not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
Comment on lines +866 to +873
Collaborator: that's a lot of work 😓
Does not seem like this is needed? Two cases:

  1. No cache positions -> not using generate or not using cache positions -> use the DynamicCache, thus the previous code works for the past length
  2. cache positions -> use them

Collaborator: IMO we should move towards everybody passing the cache positions, and stop using past_seen_tokens = static_cache.get_seq_length().


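One way to read the discussion above is as three possible sources for the number of previously seen tokens. The helper below is a hypothetical condensation of that logic rather than the code the PR ends up with; `resolve_past_seen_tokens` and its argument names are invented, while `Cache` and `DynamicCache` are the real classes from `cache_utils.py`:

```python
from typing import Optional

import torch

from transformers.cache_utils import Cache, DynamicCache


def resolve_past_seen_tokens(past_key_values, static_cache, cache_position: Optional[torch.LongTensor]) -> int:
    if cache_position is not None:
        # Direction the reviewers prefer: trust the explicitly passed positions.
        return int(cache_position[0])
    if static_cache is not None:
        # Fallback used by the diff when a static cache lives on the attention modules.
        return int(static_cache.get_seq_length())
    # Dynamic / legacy path: wrap a tuple cache (or None) and ask it how much it has seen.
    if not isinstance(past_key_values, Cache):
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return int(past_key_values.get_seq_length())
```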
if cache_position is None:
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# `torch.compile`-friendly `torch.arange` from a shape
Collaborator: does that also fix the ONNX export we had?

cache_position = (
torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
)

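The swap above trades `torch.arange` endpoints built from Python ints for a shape-derived `cumsum`, so the sequence length stays a tensor operation inside the traced graph for `torch.compile` (and possibly the export paths asked about above). A quick sketch, with made-up shapes, checking that the two constructions agree:

```python
import torch

past_seen_tokens = 7
inputs_embeds = torch.randn(1, 5, 32)   # [batch, seq_len, hidden]

# Removed lines: endpoints are Python ints pulled out of a shape.
arange_version = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])

# Added lines: the length comes from a tensor of the right shape, never leaving tensor land.
cumsum_version = (
    torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
)

assert torch.equal(arange_version, cumsum_version)   # both are tensor([7, 8, 9, 10, 11])
```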
if position_ids is None:
@@ -1101,14 +1110,24 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
past_length = (
cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length()
)
max_cache_length = past_key_values.get_max_length()
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
Comment on lines +1125 to +1129
Member Author: This restructure prioritizes cache_position, falling back to .get_seq_length() in its absence. This replaces seen_tokens, which is now deprecated.

Note that past_length [all seen tokens] and cache_length [tokens in the cache] are both needed, otherwise SinkCache won't work.

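A hypothetical condensation of the prioritisation described in the comment above: `cache_position` wins when present, `.get_seq_length()` is the fallback, and `cache_length` is capped by the cache's maximum length so windowed caches such as `SinkCache` keep working. The function and argument names are invented for illustration:

```python
from typing import Optional, Tuple

import torch


def resolve_lengths(
    cache_position: Optional[torch.LongTensor],
    seq_length: int,
    max_cache_length: Optional[int],
) -> Tuple[int, int]:
    """Returns (past_length, cache_length) as sketched in the comment above."""
    # All tokens the model has ever seen: trust `cache_position` when the caller provides it.
    past_length = int(cache_position[-1]) + 1 if cache_position is not None else seq_length
    # Tokens actually held in the cache: bounded by the cache's maximum length, if any.
    cache_length = past_length if max_cache_length is None else min(max_cache_length, past_length)
    return past_length, cache_length


# A SinkCache-like case: 100 tokens seen overall, but the window only holds 16 of them.
print(resolve_lengths(torch.arange(90, 100), seq_length=100, max_cache_length=16))   # (100, 16)
```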
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1141,19 +1160,11 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None
Comment on lines +1165 to +1167
Member Author: This is already on main (see here), not sure why this shows up 👀


# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
@@ -1164,6 +1175,9 @@ def prepare_inputs_for_generation(
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids.contiguous(),
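Finally, a trimmed-down sketch of the static-cache bookkeeping in `prepare_inputs_for_generation` above: when the cache lives on the attention modules it is looked up there, and `past_key_values` is reset to `None` before the inputs go back to the model. The helper name is hypothetical, and the real method also slices `input_ids` and rebuilds `position_ids`/`attention_mask`, which is omitted here:

```python
from typing import Any, Dict


def prepare_cache_inputs(model, past_key_values=None, cache_position=None) -> Dict[str, Any]:
    # With a static cache, `generate` passes past_key_values=None; the cache itself was
    # attached to the attention modules when the static cache was set up on the model.
    has_static_cache = False
    if past_key_values is None:
        past_key_values = getattr(model.model.layers[0].self_attn, "past_key_value", None)
        has_static_cache = past_key_values is not None

    if has_static_cache:
        # Don't feed the static cache back through the forward signature; the layers already hold it.
        past_key_values = None

    return {"past_key_values": past_key_values, "cache_position": cache_position}
```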
60 changes: 36 additions & 24 deletions src/transformers/models/llama/modeling_llama.py
@@ -357,7 +357,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -446,7 +446,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -645,7 +645,7 @@ def forward(
past_key_value = getattr(self, "past_key_value", past_key_value)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

@@ -892,6 +892,10 @@ def _reset_cache(self):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


@@ -967,16 +971,19 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if use_cache:
static_cache = getattr(self.layers[0].self_attn, "past_key_value", None)
if static_cache is not None:
past_seen_tokens = static_cache.get_seq_length()
else:
if not isinstance(past_key_values, Cache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
# `torch.compile`-friendly `torch.arange` from a shape
cache_position = (
torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1
)

if position_ids is None:
@@ -1212,14 +1219,24 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
past_length = (
cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length()
)
max_cache_length = past_key_values.get_max_length()
cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length))
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
@@ -1252,19 +1269,11 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
past_length = 0
else:
past_length = cache_position[-1] + 1
input_ids = input_ids[:, past_length:]
position_ids = position_ids[:, past_length:]

# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
position_ids = position_ids.contiguous() if position_ids is not None else None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
@@ -1275,6 +1284,9 @@ def prepare_inputs_for_generation(
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids.contiguous(),