Skip to content

Commit

Permalink
[BC] Fix BC for other libraries (#29934)
Browse files Browse the repository at this point in the history
* fi xbc?

* nit
  • Loading branch information
ArthurZucker authored and amyeroberts committed Mar 28, 2024
1 parent 1b6d501 commit 02b1012
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,7 +1146,7 @@ def prepare_inputs_for_generation(
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def prepare_inputs_for_generation(
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,7 @@ def prepare_inputs_for_generation(
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
Expand Down

0 comments on commit 02b1012

Please sign in to comment.