diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 87d24c6cf663..13bac74c986c 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -4,6 +4,10 @@
 import torch
 
 from .configuration_utils import PretrainedConfig
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
 
 
 @dataclass
@@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
             return max_length - new_seq_length
         return previous_seq_length
 
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` "
+            "variable instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
 
 class DynamicCache(Cache):
     """
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
     def __init__(self) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
@@ -121,7 +136,7 @@ def update(
         """
         # Update the number of seen tokens
        if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]
 
         # Update the cache
         if len(self.key_cache) <= layer_idx:
@@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None:
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
         self.cos_sin_cache = {}
-        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
     def _rotate_half(x):
@@ -272,7 +287,7 @@ def update(
 
         # Update the number of seen tokens
         if layer_idx == 0:
-            self.seen_tokens += key_states.shape[-2]
+            self._seen_tokens += key_states.shape[-2]
 
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
@@ -398,16 +413,11 @@ def update(
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
-
-    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-        # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-        raise ValueError(
-            "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-        )
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()
 
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 165ef5a05451..45238b2e4dbc 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -247,7 +247,7 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -334,7 +334,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -533,7 +533,7 @@ def forward(
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
@@ -782,6 +782,10 @@ def _reset_cache(self):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
""" @@ -859,14 +863,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) if position_ids is None: @@ -1101,14 +1110,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1141,19 +1160,11 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + position_ids = position_ids.contiguous() if position_ids is not None else None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1164,6 +1175,9 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
model_inputs = {"input_ids": input_ids.contiguous()} + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids.contiguous(), diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e494adefc2d..c4ec236d9938 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -357,7 +357,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -446,7 +446,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -645,7 +645,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -892,6 +892,10 @@ def _reset_cache(self): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
""" @@ -967,16 +971,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) if position_ids is None: @@ -1212,14 +1219,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): + # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if + has_static_cache = False + if past_key_values is None: + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1252,19 +1269,11 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. # same goes for position ids. Could also help with continued generation. 
-        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        position_ids = position_ids.contiguous() if position_ids is not None else None
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -1275,6 +1284,9 @@ def prepare_inputs_for_generation(
             # TODO: use `next_tokens` directly instead.
             model_inputs = {"input_ids": input_ids.contiguous()}
 
+        if has_static_cache:
+            past_key_values = None
+
         model_inputs.update(
             {
                 "position_ids": position_ids.contiguous(),
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 6d31d63e82ef..0b194417bb5e 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -291,7 +291,7 @@ def test_sink_cache_iterative_prompts(self):
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
@@ -331,7 +331,7 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @require_torch_gpu
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
-    def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
+    def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
             "The best color isЋ the one that complements the skin tone of",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]
@@ -382,6 +382,76 @@ def call(input_ids, **kwargs):
         with self.subTest(f"{attn_implementation}, static, compiled"):
             self.assertListEqual(decoded, EXPECTED_GENERATION)
 
+    def test_dynamic_cache_extra_left_padding(self):
+        """Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
+        EXPECTED_GENERATION = [
+            "The best color is the one that complements the skin tone of the",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
+        ]
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf",
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        inputs = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
+        ).to(model.device)
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+        # Now with extra left-padding
+        inputs_expanded = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"],
+            padding=True,
+            return_tensors="pt",
+            pad_to_multiple_of=32,
+        ).to(model.device)
+        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
+        gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+    def test_static_cache_extra_left_padding(self):
+        """Tests that adding extra left-padding does not affect the generation with the static cache"""
+        EXPECTED_GENERATION = [
+            "The best color is the one that complements the skin tone of the",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
+        ]
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            "NousResearch/Llama-2-7b-chat-hf",
+            torch_dtype=torch.bfloat16,
+        ).to(torch_device)
+        inputs = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
+        ).to(model.device)
+
+        model.generation_config.cache_implementation = "static"
+
+        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
+        # Now with extra left-padding
+        inputs_expanded = tokenizer(
+            ["The best color is", "We should not undermind the issues at hand"],
+            padding=True,
+            return_tensors="pt",
+            pad_to_multiple_of=32,
+        ).to(model.device)
+        self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
+        gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
+        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
+        self.assertListEqual(decoded, EXPECTED_GENERATION)
+
     @unittest.skip("TODO @gante static cache's does not support beam search yet")
     def test_static_cache_beam_search(self):
         pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 32f6abcbe3aa..5b63a0667da8 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -103,7 +103,7 @@
     from safetensors.torch import save_file as safe_save_file
     from torch import nn
 
-    from transformers import MODEL_MAPPING, AdaptiveEmbedding
+    from transformers import MODEL_MAPPING, AdaptiveEmbedding, StaticCache
     from transformers.modeling_utils import load_state_dict, no_init_weights
     from transformers.pytorch_utils import id_tensor_storage
 
@@ -3937,6 +3937,78 @@ def test_flash_attn_2_from_config(self):
 
             self.assertFalse(fa2_correctly_converted)
 
+    @require_torch_gpu
+    @slow
+    def test_implicit_cache_position(self):
+        """
+        Tests that passing the correct cache_position yields the same results as passing cache_position=None, i.e. that
+        inference with implicit cache_position is working.
+ """ + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + + input_ids = inputs_dict["input_ids"].to(torch_device) + + def run_2_forward_passes_with_cache(model, input_ids, static_cache, compile): + # runs two generate-style forward passes, to ensure cudagraphs need two different values of implicit + # `cache_position` to work correctly + if static_cache: + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + if compile: + model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + + # Implicit cache_positions + logits_implicit = [] + outputs = model(input_ids, cache_position=None) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_implicit.append(outputs.logits) + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=None, past_key_values=outputs.past_key_values) + logits_implicit.append(outputs.logits) + + if static_cache: + # Restart the cache + model._reset_cache() + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + # Explicit cache_positions + logits_explicit = [] + cache_positions = torch.arange(input_ids.shape[1], dtype=torch.long, device=torch_device) + outputs = model(input_ids, cache_position=cache_positions) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_explicit.append(outputs.logits) + cache_positions = cache_positions[-1:] + 1 + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=cache_positions, past_key_values=outputs.past_key_values) + logits_explicit.append(outputs.logits) + + if static_cache: + model._reset_cache() + + # Confirm that explicit and implicity cache_positions yield the same results + for idx in range(len(logits_implicit)): + self.assertTrue(torch.allclose(logits_implicit[idx], logits_explicit[idx])) + + # dynamic cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=False, compile=False) + + # eager static cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=False) + + # compiled static cache [to confirm that it works with cuda graphs] + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=True) + global_rng = random.Random()