Skip to content

Commit

Permalink
[Phi-3] Bug on stale kv cache (#33129)
Browse files Browse the repository at this point in the history
* fix long seq bug

* fixed format

* fixed fn copy inconsistency

* fix long seq bug

* fixed format

* fixed fn copy inconsistency

* Addressed comments

* added a unit test

* fixed cache position

* Added a warning msg to the forward fn

* fixed test case
  • Loading branch information
garg-amit authored Sep 13, 2024
1 parent 7a56598 commit dfd3115
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
23 changes: 21 additions & 2 deletions src/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __init__(self, dim, config, device=None):

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
seq_len = seq_len or torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
else:
Expand Down Expand Up @@ -1239,6 +1239,15 @@ def forward(
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
```"""
if (
use_cache
and self.config.rope_scaling
and cache_position is not None
and cache_position[0] == self.config.original_max_position_embeddings
):
logger.warning(
f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
)

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -1295,7 +1304,6 @@ def forward(
attentions=outputs.attentions,
)

# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self,
input_ids,
Expand All @@ -1308,6 +1316,17 @@ def prepare_inputs_for_generation(
num_logits_to_keep=None,
**kwargs,
):
# When the first time input length reached long and short factor switching point, enforce re-compute cache
# It will cause downside of slower at this single token position, however, better than current failure.
if (
past_key_values
and self.config.rope_scaling
and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
):
past_length = cache_position[0]
if past_length <= self.config.original_max_position_embeddings:
past_key_values = None

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
Expand Down
41 changes: 41 additions & 0 deletions tests/models/phi3/test_modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,47 @@ def test_model_rope_scaling_from_config(self, scaling_type):
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

@parameterized.expand([("longrope",)])
def test_model_rope_scaling_short_long_factor(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
n_factors = config.hidden_size // config.num_key_value_heads // 2
config.rope_scaling = {
"type": scaling_type,
"short_factor": [3.0 for _ in range(n_factors)],
"long_factor": [5.0 for _ in range(n_factors)],
}
input_tensor = ids_tensor([1, 4090], config.vocab_size)
model = Phi3ForCausalLM(config)
model.to(torch_device)
model.eval()
generation_args_short = {
"max_length": config.original_max_position_embeddings,
"temperature": 0.0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
}
output_with_short_factor = model.generate(input_tensor, **generation_args_short)
keys_with_short_factor = output_with_short_factor.past_key_values[0][0]
generation_args_long = {
"max_length": config.original_max_position_embeddings + 5,
"temperature": 0.0,
"use_cache": True,
"do_sample": False,
"return_dict_in_generate": True,
"output_logits": True,
}
output_with_long_factor = model.generate(input_tensor, **generation_args_long)
keys_with_long_factor = output_with_long_factor.past_key_values[0][0]
last_token_logits = output_with_long_factor.logits[-1][-1]
regenerated_last_token_logits = model(output_with_long_factor.sequences[:, :-1]).logits[0][-1]
keys_with_long_factor = keys_with_long_factor[:, :, : config.original_max_position_embeddings - 1, :]

# KV cache is re-computed after reaching the (`config.original_max_position_embeddings`+1)th token position
self.assertFalse(torch.allclose(keys_with_short_factor, keys_with_long_factor, atol=1e-2, rtol=1e-2))
# Last token generated using long factor
self.assertTrue(torch.allclose(last_token_logits, regenerated_last_token_logits, atol=1e-2, rtol=1e-2))


@slow
@require_torch
Expand Down

0 comments on commit dfd3115

Please sign in to comment.