fix merge conflicts between llama and gemma
ArthurZucker committed on Mar 1, 2024 (commit bf5163f, 1 parent 6c45f0f)
Showing 3 changed files with 13 additions and 16 deletions.
src/transformers/models/gemma/modeling_gemma.py (6 changes: 1 addition & 5 deletions)
@@ -949,10 +949,6 @@ def forward(
             attentions=all_self_attns,
         )

-    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
@@ -1156,7 +1152,7 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
+        if self.generation_config.cache_implementation == "static":
             # generation with static cache
             cache_position = kwargs.get("cache_position", None)
             if cache_position is None:
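After this commit, both Gemma and Llama gate the static-cache branch on self.generation_config.cache_implementation == "static" instead of probing the first attention layer for a past_key_value attribute (the Llama diff below makes the same switch). A minimal usage sketch of the path this targets, tying into the cudagraph-recapture TODO removed above; the checkpoint name, dtype, and generation settings are illustrative and not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only model with static-cache support behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).to("cuda")

# Opt in to the static KV cache path that prepare_inputs_for_generation now checks for.
model.generation_config.cache_implementation = "static"

# With static shapes the forward pass can be compiled once with fullgraph=True instead of
# recapturing cudagraphs at every decode step (the issue described in the removed TODO).
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))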
src/transformers/models/llama/modeling_llama.py (21 changes: 12 additions & 9 deletions)
@@ -1261,29 +1261,32 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
+        if self.generation_config.cache_implementation == "static":
             # generation with static cache
-            past_length = past_key_value.get_seq_length()
+            cache_position = kwargs.get("cache_position", None)
+            if cache_position is None:
+                past_length = 0
+            else:
+                past_length = cache_position[-1] + 1
             input_ids = input_ids[:, past_length:]
             position_ids = position_ids[:, past_length:]

         # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
         # same goes for position ids. Could also help with continued generation.
-        cache_position = kwargs.get("cache_position", None)
-        if cache_position is None:
-            cache_position = torch.arange(
-                past_length, past_length + position_ids.shape[-1], device=position_ids.device
-            )
+        cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}

         model_inputs.update(
             {
-                "position_ids": position_ids,
+                "position_ids": position_ids.contiguous(),
                 "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
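Two details of the Llama change are worth unpacking. First, with a static cache past_length is now derived from cache_position (the last cached position plus one) rather than read off the cache object. Second, the new contiguous() calls exist because torchdynamo guards on tensor strides: the stride of a slice such as input_ids[:, past_length:] depends on the full sequence length, so it changes at every decode step and forces recompilation, whereas a contiguous copy always has the same stride. A small standalone illustration (my own sketch, not code from the commit):

import torch

# past_length from cache_position: after 5 cached tokens the last position was 4.
prev_cache_position = torch.tensor([0, 1, 2, 3, 4])
past_length = int(prev_cache_position[-1]) + 1                # 5
position_ids = torch.tensor([[5]])                            # one new token during decoding
cache_position = torch.arange(past_length, past_length + position_ids.shape[-1])
print(cache_position)                                         # tensor([5])

# The stride of a slice changes as the sequence grows...
ids_step_6 = torch.arange(12).reshape(2, 6)                   # pretend input_ids with 6 tokens
ids_step_7 = torch.arange(14).reshape(2, 7)                   # and with 7 tokens
print(ids_step_6[:, -1:].stride())                            # (6, 1)
print(ids_step_7[:, -1:].stride())                            # (7, 1)

# ...while a contiguous copy keeps the same stride every step, so the compiled graph is reused.
print(ids_step_6[:, -1:].contiguous().stride())               # (1, 1)
print(ids_step_7[:, -1:].contiguous().stride())               # (1, 1)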
tests/models/llama/test_modeling_llama.py (2 changes: 0 additions & 2 deletions)
@@ -25,7 +25,6 @@
     CaptureLogger,
     require_bitsandbytes,
     require_flash_attn,
-    require_read_token,
     require_torch,
     require_torch_accelerator,
     require_torch_gpu,
@@ -599,7 +598,6 @@ def test_model_13b_greedy_generation(self):

     @slow
     @require_torch_gpu
-    @require_read_token
     def test_compile_static_cache(self):
         NUM_TOKENS_TO_GENERATE = 40
         EXPECTED_TEXT_COMPLETION = [
