Skip to content

Commit

Permalink
[Llama + AWQ] fix prepare_inputs_for_generation 🫠 (#29381)
Browse files Browse the repository at this point in the history
* use the generation config 🫠

* fixup
  • Loading branch information
ArthurZucker authored and Ita Zaporozhets committed May 14, 2024
1 parent 32c1fe0 commit 6dac2e9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,7 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
if self.generation_config.cache_implementation == "static":
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ def prepare_inputs_for_generation(
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
if self.generation_config.cache_implementation == "static":
# generation with static cache
cache_position = kwargs.get("cache_position", None)
if cache_position is None:
Expand Down

0 comments on commit 6dac2e9

Please sign in to comment.