From e9ca1ea2ce98fee2dd9b5716c77ac5e96007a34e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:07:40 +0000 Subject: [PATCH 01/12] no cache positions in the public api --- .../models/llama/modeling_llama.py | 45 ++++------- tests/test_cache_utils.py | 74 ++++++++++++++++++- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e494adefc2d..a945805b1355 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -357,7 +357,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -446,7 +446,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -625,7 +625,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -645,7 +644,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -943,7 +942,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -972,15 +970,21 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: + if position_ids is None: if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") - cache_position = torch.arange( + raise ValueError("position_ids is a required argument when using StaticCache.") + position_ids = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + ).unsqueeze(0) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding + # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
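# [Illustrative sketch, not part of this patch; tensor values and variable names are
# hypothetical] Worked example of the derivation below, for a batch where only the
# first prompt is left-padded:
import torch

attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
position_ids = torch.tensor([[0, 0, 0, 1], [0, 1, 2, 3]])
cache_position = torch.max(position_ids, dim=0).values        # tensor([0, 1, 2, 3])
shared_padding = torch.sum(attention_mask == 0, dim=1).min()  # tensor(0): no padding common to all rows
print(cache_position + shared_padding)                        # tensor([0, 1, 2, 3])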
+ cache_position = torch.max(position_ids, dim=0).values + if attention_mask is None: + padded_positions = 0 + else: + padded_positions = torch.sum(attention_mask == 0, dim=1).min() + cache_position = cache_position + padded_positions causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1130,7 +1134,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1174,7 +1177,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) hidden_states = outputs[0] @@ -1248,24 +1250,10 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask == 0, 0) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1278,7 +1266,6 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), - "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 6d31d63e82ef..0b194417bb5e 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -291,7 +291,7 @@ def test_sink_cache_iterative_prompts(self): @require_torch_gpu @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) - def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): + def test_static_cache_greedy_decoding_pad_left(self, attn_implementation): EXPECTED_GENERATION = [ "The best color is the one that complements the skin tone of the", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -331,7 +331,7 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation): @require_torch_gpu @parameterized.expand(["eager", "sdpa", "flash_attention_2"]) - def test_static_cache_greedy_sampling_pad_right(self, attn_implementation): + def test_static_cache_greedy_decoding_pad_right(self, attn_implementation): EXPECTED_GENERATION = [ "The best color isЋ the one that complements the skin tone of", "We should not undermind the issues at hand.\nWe should not undermind the issues", @@ -382,6 +382,76 @@ def call(input_ids, **kwargs): with self.subTest(f"{attn_implementation}, static, compiled"): self.assertListEqual(decoded, 
EXPECTED_GENERATION) + def test_dynamic_cache_extra_left_padding(self): + """Tests that adding extra left-padding does not affect the generation with the dynamic cache""" + EXPECTED_GENERATION = [ + "The best color is the one that complements the skin tone of the", + "We should not undermind the issues at hand.\nWe should not undermind the issues", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + ).to(torch_device) + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + # Now with extra left-padding + inputs_expanded = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], + padding=True, + return_tensors="pt", + pad_to_multiple_of=32, + ).to(model.device) + self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) + gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + def test_static_cache_extra_left_padding(self): + """Tests that adding extra left-padding does not affect the generation with the static cache""" + EXPECTED_GENERATION = [ + "The best color is the one that complements the skin tone of the", + "We should not undermind the issues at hand.\nWe should not undermind the issues", + ] + + tokenizer = AutoTokenizer.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="" + ) + model = AutoModelForCausalLM.from_pretrained( + "NousResearch/Llama-2-7b-chat-hf", + torch_dtype=torch.bfloat16, + ).to(torch_device) + inputs = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt" + ).to(model.device) + + model.generation_config.cache_implementation = "static" + + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + + # Now with extra left-padding + inputs_expanded = tokenizer( + ["The best color is", "We should not undermind the issues at hand"], + padding=True, + return_tensors="pt", + pad_to_multiple_of=32, + ).to(model.device) + self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1]) + gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) + @unittest.skip("TODO @gante static cache's does not support beam search yet") def test_static_cache_beam_search(self): pass From 3b7fbfbb1e82ac741beb280b43fc3f8a15978f45 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:14:24 +0000 Subject: [PATCH 02/12] propagate changes to gemma --- src/transformers/generation/utils.py | 7 +-- .../models/gemma/modeling_gemma.py | 44 +++++++------------ 2 files changed, 19 insertions(+), 32 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 
d337e5593440..d878fc8ebdd4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -678,8 +678,6 @@ def _update_model_kwargs_for_generation( dim=-1, ) - model_kwargs["cache_position"] = model_inputs.get("cache_position", None) - return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -4931,9 +4929,8 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. # bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] - keys_to_ignore = ["cache_position", "encoder_outputs"] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] + bool_keys = [k for k in keys if isinstance(model_input[k], bool)] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k != "encoder_outputs"] # we split the tensors and tuples of tensors data_split_list = [ diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 165ef5a05451..fd58d1bb1765 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -247,7 +247,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -334,7 +334,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -533,7 +533,7 @@ def forward( past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - # sin and cos are specific to RoPE models; position_ids needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -835,7 +835,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -864,13 +863,21 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - if cache_position is None: - cache_position = torch.arange( + if position_ids is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("position_ids is a required argument when using 
StaticCache.") + position_ids = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + ).unsqueeze(0) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) + # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding + # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. + cache_position = torch.max(position_ids, dim=0).values + if attention_mask is None: + padded_positions = 0 + else: + padded_positions = torch.sum(attention_mask == 0, dim=1).min() + cache_position = cache_position + padded_positions causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1025,7 +1032,6 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1069,7 +1075,6 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) hidden_states = outputs[0] @@ -1137,24 +1142,10 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask == 0, 0) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None: - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. 
- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1167,7 +1158,6 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), - "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, From 694b26580c7d159df4dc8261815afdf8a6f0221f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 18:19:03 +0000 Subject: [PATCH 03/12] should not have been deleted --- src/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a945805b1355..e2fa5249fb41 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -625,6 +625,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() From 75aebbe0e4f9fed6a73eb74c2e23d58b59bb06bc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 19:03:44 +0000 Subject: [PATCH 04/12] more precise padded offset calculation --- src/transformers/models/gemma/modeling_gemma.py | 9 ++++++--- src/transformers/models/llama/modeling_llama.py | 9 ++++++--- tests/test_cache_utils.py | 6 +++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index fd58d1bb1765..133049dad52a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -874,10 +874,13 @@ def forward( # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. cache_position = torch.max(position_ids, dim=0).values if attention_mask is None: - padded_positions = 0 + padded_offset = 0 else: - padded_positions = torch.sum(attention_mask == 0, dim=1).min() - cache_position = cache_position + padded_positions + padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = torch.cat( + (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + )[-cache_position.shape[0] - 1 : -1] + cache_position = cache_position + padded_offset causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e2fa5249fb41..6afd2c08cdfb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -982,10 +982,13 @@ def forward( # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
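# [Illustrative sketch, not part of this patch; tensor values are hypothetical]
# Worked example of the refined offset computed below, for a batch where every row
# carries some left-padding:
import torch

attention_mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
position_ids = torch.tensor([[0, 0, 0, 1], [0, 0, 1, 2]])
cache_position = torch.max(position_ids, dim=0).values                   # tensor([0, 0, 1, 2])
offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1)  # tensor([1, 1, 1, 1])
offset = torch.cat((torch.zeros(1, dtype=torch.int64), offset))[-cache_position.shape[0] - 1 : -1]  # tensor([0, 1, 1, 1])
print(cache_position + offset)                                           # tensor([0, 1, 2, 3])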
cache_position = torch.max(position_ids, dim=0).values if attention_mask is None: - padded_positions = 0 + padded_offset = 0 else: - padded_positions = torch.sum(attention_mask == 0, dim=1).min() - cache_position = cache_position + padded_positions + padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = torch.cat( + (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + )[-cache_position.shape[0] - 1 : -1] + cache_position = cache_position + padded_offset causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index 0b194417bb5e..a134e916630b 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -436,9 +436,9 @@ def test_static_cache_extra_left_padding(self): model.generation_config.cache_implementation = "static" - gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - self.assertListEqual(decoded, EXPECTED_GENERATION) + # gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + # decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + # self.assertListEqual(decoded, EXPECTED_GENERATION) # Now with extra left-padding inputs_expanded = tokenizer( From 88d597b88bb0c948ae985d3a252b5ab341fd6cb3 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 Feb 2024 19:20:14 +0000 Subject: [PATCH 05/12] attention mask dtype is sometimes wrong in the tests --- src/transformers/models/gemma/modeling_gemma.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 133049dad52a..70bd229d8c8b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -876,9 +876,9 @@ def forward( if attention_mask is None: padded_offset = 0 else: - padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) padded_offset = torch.cat( - (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) )[-cache_position.shape[0] - 1 : -1] cache_position = cache_position + padded_offset diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 6afd2c08cdfb..f147fb6d844c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -984,9 +984,9 @@ def forward( if attention_mask is None: padded_offset = 0 else: - padded_offset = (1 - torch.sum(attention_mask, dim=0).clamp(max=1)).cumsum(-1) + padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) padded_offset = torch.cat( - (torch.zeros((1,), dtype=padded_offset.dtype, device=padded_offset.device), padded_offset) + (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) )[-cache_position.shape[0] - 1 : -1] cache_position = cache_position + padded_offset From e499ac9bd60a65efbbc3e1b244b0e2ac74b26f1f Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 18:41:02 +0000 Subject: [PATCH 06/12] get_seq_length() working --- 
src/transformers/cache_utils.py | 13 ++--- .../models/llama/modeling_llama.py | 48 +++++++++++-------- tests/test_cache_utils.py | 6 +-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 87d24c6cf663..1ed5780a5ef1 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -398,16 +398,9 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" - # TODO: Fix once the stateful `int` bug in PyTorch is fixed. - raise ValueError( - "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." - ) - - def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int: - # TODO: Fix once the stateful `int` bug in PyTorch is fixed. - raise ValueError( - "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114." - ) + # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's + # check the first batch member and the first head only. + return (self.key_cache[0, 0].sum(dim=-1) != 0).sum() def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f147fb6d844c..4e71fcffbba4 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -966,29 +966,22 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): + if use_cache: + if past_key_values is not None and not isinstance(past_key_values, Cache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # non-static cache + if past_key_values is not None: past_seen_tokens = past_key_values.get_seq_length() - + # static cache + elif past_key_values is None: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 if position_ids is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("position_ids is a required argument when using StaticCache.") - position_ids = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding - # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
- cache_position = torch.max(position_ids, dim=0).values - if attention_mask is None: - padded_offset = 0 - else: - padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) - padded_offset = torch.cat( - (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) - )[-cache_position.shape[0] - 1 : -1] - cache_position = cache_position + padded_offset + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1220,12 +1213,22 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # With static cache, the `past_key_values` is None + has_static_cache = False + if past_key_values is None: + has_static_cache = True + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() + # TODO joao: find a better way to track the total number of tokens seen in the static cache + if max_cache_length is not None: + past_length = cache_length + else: + past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1267,6 +1270,9 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids.contiguous(), diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py index a134e916630b..0b194417bb5e 100644 --- a/tests/test_cache_utils.py +++ b/tests/test_cache_utils.py @@ -436,9 +436,9 @@ def test_static_cache_extra_left_padding(self): model.generation_config.cache_implementation = "static" - # gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) - # decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) - # self.assertListEqual(decoded, EXPECTED_GENERATION) + gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10) + decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True) + self.assertListEqual(decoded, EXPECTED_GENERATION) # Now with extra left-padding inputs_expanded = tokenizer( From 6cc17ecf7fd9770d202bd5f8707169240e178757 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:05:23 +0000 Subject: [PATCH 07/12] nits --- src/transformers/cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1ed5780a5ef1..250b25d5b010 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -399,8 +399,8 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # check the first batch member and the first head only. - return (self.key_cache[0, 0].sum(dim=-1) != 0).sum() + # limit the check to the first batch member and head dimension. 
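# [Illustrative sketch, not part of this patch; toy shapes] The static cache is
# pre-allocated with zeros of shape (batch, num_heads, max_cache_len, head_dim), so
# slots that were never written stay all-zero along the head dimension and the
# occupancy check below counts only the filled positions:
import torch

key_cache = torch.zeros(1, 1, 4, 8)         # max_cache_len=4, head_dim=8
key_cache[0, 0, :2] = torch.randn(2, 8)     # two positions written so far
print((key_cache[0, 0].any(dim=-1)).sum())  # tensor(2)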
+ return (self.key_cache[0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" From 6e4b511e214ea7e0436d7a88bac0c3fbc9ece547 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:20:18 +0000 Subject: [PATCH 08/12] gemma --- .../models/gemma/modeling_gemma.py | 47 ++++++++++--------- .../models/llama/modeling_llama.py | 15 +++--- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 70bd229d8c8b..9bf9dd87c040 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -858,29 +858,19 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 if position_ids is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("position_ids is a required argument when using StaticCache.") - position_ids = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ).unsqueeze(0) - - # One of the rows in `position_ids` contains the highest sequence of cache indexes, excluding left-padding - # applied on all batch members. Left-padding on all batch members can be detected from the `attention_mask`. 
- cache_position = torch.max(position_ids, dim=0).values - if attention_mask is None: - padded_offset = 0 - else: - padded_offset = (1 - torch.sum(attention_mask.to(torch.int64), dim=0).clamp(max=1)).cumsum(-1) - padded_offset = torch.cat( - (torch.zeros((1,), dtype=torch.int64, device=padded_offset.device), padded_offset) - )[-cache_position.shape[0] - 1 : -1] - cache_position = cache_position + padded_offset + position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) @@ -1111,12 +1101,22 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # With static cache, the `past_key_values` is None + has_static_cache = False + if past_key_values is None: + has_static_cache = True + past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens max_cache_length = past_key_values.get_max_length() + # TODO joao: find a better way to track the total number of tokens seen in the static cache + if max_cache_length is not None: + past_length = cache_length + else: + past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None @@ -1158,6 +1158,9 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} + if has_static_cache: + past_key_values = None + model_inputs.update( { "position_ids": position_ids.contiguous(), diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4e71fcffbba4..a5479c99ba43 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -967,16 +967,13 @@ def forward( past_seen_tokens = 0 if use_cache: - if past_key_values is not None and not isinstance(past_key_values, Cache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # non-static cache - if past_key_values is not None: + static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) + if static_cache is not None: + past_seen_tokens = static_cache.get_seq_length() + else: + if not isinstance(past_key_values, Cache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # static cache - elif past_key_values is None: - static_cache = getattr(self.layers[0].self_attn, "past_key_value", None) - if static_cache is not None: - past_seen_tokens = static_cache.get_seq_length() # `torch.compile`-friendly `torch.arange` from a shape cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 From 232da2a1e71427dd01b1a66aeeff4e939e7576b5 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 23 Feb 2024 19:36:27 +0000 Subject: [PATCH 09/12] bc nit --- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9bf9dd87c040..155e1b9c18e4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1145,7 +1145,7 @@ def 
prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) + position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index a5479c99ba43..b2f266861ecf 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1254,7 +1254,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) + position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] From 04d53a75a65de6a0c625bd751d063519d09baef9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 15:21:29 +0000 Subject: [PATCH 10/12] explicit cache_positions (implicit working when not passed) --- src/transformers/cache_utils.py | 2 ++ src/transformers/generation/utils.py | 7 +++++-- src/transformers/models/gemma/modeling_gemma.py | 17 +++++++++++++++-- src/transformers/models/llama/modeling_llama.py | 17 +++++++++++++++-- 4 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 250b25d5b010..382fef1085e9 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -400,6 +400,8 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. + # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after + # https://github.com/pytorch/pytorch/issues/120248 is fixed return (self.key_cache[0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index d878fc8ebdd4..d337e5593440 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -678,6 +678,8 @@ def _update_model_kwargs_for_generation( dim=-1, ) + model_kwargs["cache_position"] = model_inputs.get("cache_position", None) + return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -4929,8 +4931,9 @@ def _split_model_inputs( # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a # ModelOutput object. 
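# [Illustrative sketch, not part of this patch] `cache_position` is a 1-D tensor of
# shape (sequence_length,) with no batch dimension, so when model inputs are split
# along the batch it has to be replicated for each split, like a bool flag, rather
# than sliced:
import torch

cache_position = torch.arange(5)                                  # shared by every batch member
splits = [{"cache_position": cache_position} for _ in range(2)]   # replicate, do not slice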
# bool should not be split but replicated for each split - bool_keys = [k for k in keys if isinstance(model_input[k], bool)] - non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k != "encoder_outputs"] + bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"] + keys_to_ignore = ["cache_position", "encoder_outputs"] + non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore] # we split the tensors and tuples of tensors data_split_list = [ diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 155e1b9c18e4..8a82896f3dcc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -835,6 +835,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -867,8 +868,11 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # `torch.compile`-friendly `torch.arange` from a shape - cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + if cache_position is None: + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1025,6 +1029,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1068,6 +1073,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1149,6 +1155,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. 
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + position_ids = position_ids.contiguous() if position_ids is not None else None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1164,6 +1176,7 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b2f266861ecf..56156f291682 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -943,6 +943,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -975,8 +976,11 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() - # `torch.compile`-friendly `torch.arange` from a shape - cache_position = torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + if cache_position is None: + # `torch.compile`-friendly `torch.arange` from a shape + cache_position = ( + torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1128,6 +1132,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1171,6 +1176,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1258,6 +1264,12 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. 
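# [Illustrative sketch, not part of this patch] How the cache_position computed below
# evolves across generation steps for a 4-token prompt:
#   prefill:        past_length=0 -> cache_position = tensor([0, 1, 2, 3])
#   1st new token:  past_length=4 -> cache_position = tensor([4])
#   2nd new token:  past_length=5 -> cache_position = tensor([5])
import torch

past_length, input_length = 4, 1   # e.g. the first decoding step after the prompt
print(torch.arange(past_length, past_length + input_length))  # tensor([4])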
+ input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + position_ids = position_ids.contiguous() if position_ids is not None else None + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1273,6 +1285,7 @@ def prepare_inputs_for_generation( model_inputs.update( { "position_ids": position_ids.contiguous(), + "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, From 20baebdfab8d0dd34e1f0d9e385dd9895379e090 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 16:31:40 +0000 Subject: [PATCH 11/12] add test for implicit cache_position --- .../models/gemma/modeling_gemma.py | 4 + .../models/llama/modeling_llama.py | 4 + tests/test_modeling_common.py | 74 ++++++++++++++++++- 3 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8a82896f3dcc..723b16abb8cd 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -782,6 +782,10 @@ def _reset_cache(self): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 56156f291682..f1d3c4245187 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -892,6 +892,10 @@ def _reset_cache(self): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 32f6abcbe3aa..5b63a0667da8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -103,7 +103,7 @@ from safetensors.torch import save_file as safe_save_file from torch import nn - from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers import MODEL_MAPPING, AdaptiveEmbedding, StaticCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage @@ -3937,6 +3937,78 @@ def test_flash_attn_2_from_config(self): self.assertFalse(fa2_correctly_converted) + @require_torch_gpu + @slow + def test_implicit_cache_position(self): + """ + Tests that passing the correct cache_position yields the same results as passing cache_position=None, i.e. that + inference with implicit cache_position is working. 
+ """ + for model_class in self.all_generative_model_classes: + if not hasattr(model_class, "_setup_cache"): + continue + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(torch_device) + + input_ids = inputs_dict["input_ids"].to(torch_device) + + def run_2_forward_passes_with_cache(model, input_ids, static_cache, compile): + # runs two generate-style forward passes, to ensure cudagraphs need two different values of implicit + # `cache_position` to work correctly + if static_cache: + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + if compile: + model = torch.compile(model, fullgraph=True, mode="reduce-overhead") + + # Implicit cache_positions + logits_implicit = [] + outputs = model(input_ids, cache_position=None) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_implicit.append(outputs.logits) + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=None, past_key_values=outputs.past_key_values) + logits_implicit.append(outputs.logits) + + if static_cache: + # Restart the cache + model._reset_cache() + model._setup_cache( + cache_cls=StaticCache, max_batch_size=input_ids.shape[0], max_cache_len=input_ids.shape[1] + 1 + ) + + # Explicit cache_positions + logits_explicit = [] + cache_positions = torch.arange(input_ids.shape[1], dtype=torch.long, device=torch_device) + outputs = model(input_ids, cache_position=cache_positions) + if static_cache: + self.assertTrue(outputs.past_key_values is None) # sanity check -- it is None with static cache + logits_explicit.append(outputs.logits) + cache_positions = cache_positions[-1:] + 1 + new_input_ids = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1) + outputs = model(new_input_ids, cache_position=cache_positions, past_key_values=outputs.past_key_values) + logits_explicit.append(outputs.logits) + + if static_cache: + model._reset_cache() + + # Confirm that explicit and implicity cache_positions yield the same results + for idx in range(len(logits_implicit)): + self.assertTrue(torch.allclose(logits_implicit[idx], logits_explicit[idx])) + + # dynamic cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=False, compile=False) + + # eager static cache + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=False) + + # compiled static cache [to confirm that it works with cuda graphs] + run_2_forward_passes_with_cache(model, input_ids, static_cache=True, compile=True) + global_rng = random.Random() From 646f150ac6453eac8202163091685b66e52da55c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 28 Feb 2024 19:08:55 +0000 Subject: [PATCH 12/12] deprecate seen_tokens --- src/transformers/cache_utils.py | 23 +++++++++++++++---- .../models/gemma/modeling_gemma.py | 17 +++++++------- .../models/llama/modeling_llama.py | 17 +++++++------- 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 382fef1085e9..13bac74c986c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -4,6 +4,10 @@ import torch from .configuration_utils import PretrainedConfig +from .utils import logging + + +logger = logging.get_logger(__name__) @dataclass @@ -57,6 +61,17 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] 
= 0) - return max_length - new_seq_length return previous_seq_length + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.40. Use the `cache_position` " + "variable instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + class DynamicCache(Cache): """ @@ -69,7 +84,7 @@ class DynamicCache(Cache): def __init__(self) -> None: self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ @@ -121,7 +136,7 @@ def update( """ # Update the number of seen tokens if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self._seen_tokens += key_states.shape[-2] # Update the cache if len(self.key_cache) <= layer_idx: @@ -191,7 +206,7 @@ def __init__(self, window_length: int, num_sink_tokens: int) -> None: self.window_length = window_length self.num_sink_tokens = num_sink_tokens self.cos_sin_cache = {} - self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen @staticmethod def _rotate_half(x): @@ -272,7 +287,7 @@ def update( # Update the number of seen tokens if layer_idx == 0: - self.seen_tokens += key_states.shape[-2] + self._seen_tokens += key_states.shape[-2] # [bsz, num_heads, seq_len, head_dim] if len(self.key_cache) <= layer_idx: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 723b16abb8cd..45238b2e4dbc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -877,6 +877,7 @@ def forward( cache_position = ( torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1109,24 +1110,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False if past_key_values is None: - has_static_cache = True past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() - # TODO joao: find a better way to track the total number of tokens seen in the static cache - if max_cache_length is not None: - past_length = cache_length - else: - past_length = past_key_values.seen_tokens + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = 
past_key_values[0][0].shape[2] max_cache_length = None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f1d3c4245187..c4ec236d9938 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -985,6 +985,7 @@ def forward( cache_position = ( torch.ones_like(inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) + past_seen_tokens - 1 ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -1218,24 +1219,24 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs ): # With static cache, the `past_key_values` is None + # TODO joao: standardize interface for the different Cache classes and remove of this if has_static_cache = False if past_key_values is None: - has_static_cache = True past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None) + has_static_cache = past_key_values is not None past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() + past_length = ( + cache_position[-1] + 1 if cache_position is not None else past_key_values.get_seq_length() + ) max_cache_length = past_key_values.get_max_length() - # TODO joao: find a better way to track the total number of tokens seen in the static cache - if max_cache_length is not None: - past_length = cache_length - else: - past_length = past_key_values.seen_tokens + cache_length = past_length if max_cache_length is None else min(max_cache_length, int(past_length)) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None
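# [Illustrative sketch, not part of this patch; shapes are hypothetical] With the
# legacy tuple cache format, each layer stores (key, value) tensors of shape
# (batch, num_heads, seq_len, head_dim), so the fallback branch above reads the
# cached length from shape[2]:
import torch

legacy_past = ((torch.zeros(2, 8, 5, 64), torch.zeros(2, 8, 5, 64)),)  # one layer, 5 cached tokens
print(legacy_past[0][0].shape[2])                                      # 5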