[BC] Fix BC for AWQ quant (#29965)
fix awq quant
TechxGenus authored and ArthurZucker committed Apr 1, 2024
1 parent 97c00cd commit 839c2a1
Showing 3 changed files with 3 additions and 3 deletions.
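Why the one-liner matters (a minimal sketch, not part of the commit itself): AWQ quantization can fuse decoder layers into modules that no longer expose a `self_attn` attribute. In `hasattr(self.layers[0].self_attn, ...)`, the attribute chain `self.layers[0].self_attn` is evaluated before `hasattr` runs, so a fused layer raises AttributeError instead of falling through to the dynamic-cache branch. Wrapping the access in `getattr` with a default restores the old behavior. The `FusedLayer` class below is a hypothetical stand-in, not a class from transformers or autoawq:

# Sketch of the assumed failure mode fixed by this commit.
class FusedLayer:
    pass  # no `self_attn` attribute, mimicking an AWQ-fused decoder layer

layer = FusedLayer()

# Old check: `layer.self_attn` is evaluated before hasattr is called, so it raises.
try:
    hasattr(layer.self_attn, "past_key_value")
except AttributeError as err:
    print("old check raises:", err)

# New check: getattr supplies a harmless default ({}), so hasattr simply
# returns False and execution takes the dynamic-cache branch instead.
print(hasattr(getattr(layer, "self_attn", {}), "past_key_value"))  # False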
2 changes: 1 addition & 1 deletion src/transformers/models/cohere/modeling_cohere.py
@@ -963,7 +963,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/modeling_gemma.py
@@ -971,7 +971,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
2 changes: 1 addition & 1 deletion src/transformers/models/llama/modeling_llama.py
@@ -1064,7 +1064,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"):  # static cache
             target_length = self.config.max_position_embeddings
         else:  # dynamic cache
             target_length = (
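Design note (inferred from the diff, not stated in the commit): any default object lacking a `past_key_value` attribute would work as the `getattr` fallback; `{}` is presumably just a cheap sentinel that makes `hasattr` return False, steering fused layers into the dynamic-cache path.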
