From dfd31158eefab01952e729588a37c9fcc81f0813 Mon Sep 17 00:00:00 2001 From: Amit Garg Date: Fri, 13 Sep 2024 05:07:19 -0700 Subject: [PATCH] [Phi-3] Bug on stale kv cache (#33129) * fix long seq bug * fixed format * fixed fn copy inconsistency * fix long seq bug * fixed format * fixed fn copy inconsistency * Addressed comments * added a unit test * fixed cache position * Added a warning msg to the forward fn * fixed test case --- src/transformers/models/phi3/modeling_phi3.py | 23 ++++++++++- tests/models/phi3/test_modeling_phi3.py | 41 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f021c6ce2d339d..273b6a8f505e79 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -257,7 +257,7 @@ def __init__(self, dim, config, device=None): @torch.no_grad() def forward(self, x, position_ids, seq_len=None): - seq_len = torch.max(position_ids) + 1 + seq_len = seq_len or torch.max(position_ids) + 1 if seq_len > self.original_max_position_embeddings: ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) else: @@ -1239,6 +1239,15 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' ```""" + if ( + use_cache + and self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.original_max_position_embeddings + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." + ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1295,7 +1304,6 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, @@ -1308,6 +1316,17 @@ def prepare_inputs_for_generation( num_logits_to_keep=None, **kwargs, ): + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index a3f001aba467a0..ce0a71878877b5 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -442,6 +442,47 @@ def test_model_rope_scaling_from_config(self, scaling_type): self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @parameterized.expand([("longrope",)]) + def test_model_rope_scaling_short_long_factor(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + n_factors = config.hidden_size // config.num_key_value_heads // 2 + config.rope_scaling = { + "type": scaling_type, + "short_factor": [3.0 for _ in range(n_factors)], + "long_factor": [5.0 for _ in range(n_factors)], + } + input_tensor = ids_tensor([1, 4090], config.vocab_size) + model = Phi3ForCausalLM(config) + model.to(torch_device) + model.eval() + generation_args_short = { + "max_length": config.original_max_position_embeddings, + "temperature": 0.0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + } + output_with_short_factor = model.generate(input_tensor, **generation_args_short) + keys_with_short_factor = output_with_short_factor.past_key_values[0][0] + generation_args_long = { + "max_length": config.original_max_position_embeddings + 5, + "temperature": 0.0, + "use_cache": True, + "do_sample": False, + "return_dict_in_generate": True, + "output_logits": True, + } + output_with_long_factor = model.generate(input_tensor, **generation_args_long) + keys_with_long_factor = output_with_long_factor.past_key_values[0][0] + last_token_logits = output_with_long_factor.logits[-1][-1] + regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1] + keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :] + + # KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position + self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2)) + # Last token generated using long factor + self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2)) + @slow @require_torch