From 466af1a356c6ebd03b544735677d6eb9086dbb44 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 16 May 2023 13:59:53 +0100
Subject: [PATCH] OPT/BioGPT: Improved attention mask shape exception (#23270)

---
 src/transformers/models/biogpt/modeling_biogpt.py | 6 ++++++
 src/transformers/models/opt/modeling_opt.py       | 5 +++++
 src/transformers/models/opt/modeling_tf_opt.py    | 9 +++++++++
 3 files changed, 20 insertions(+)

diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index 4fc7f7b4893120..df6c18182dc272 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -546,6 +546,12 @@ def forward(
 
         if attention_mask is None:
             attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != past_key_values_length + input_shape[1]:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+            )
+
         # embed positions
         positions = self.embed_positions(attention_mask, past_key_values_length)
 
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 94269ffbf0ae44..b6c84777cc1f69 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -642,6 +642,11 @@ def forward(
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
         causal_attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
diff --git a/src/transformers/models/opt/modeling_tf_opt.py b/src/transformers/models/opt/modeling_tf_opt.py
index cd34130228a61a..4f738b1605d458 100644
--- a/src/transformers/models/opt/modeling_tf_opt.py
+++ b/src/transformers/models/opt/modeling_tf_opt.py
@@ -645,6 +645,15 @@ def call(
 
         if attention_mask is None:
             attention_mask = tf.ones(inputs_embeds.shape[:2], dtype=tf.bool)
+        else:
+            tf.debugging.assert_equal(
+                attention_mask.shape[1],
+                past_key_values_length + input_shape[1],
+                message=(
+                    f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                    f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)"
+                ),
+            )
 
         pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
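
Illustration (not part of the patch): a minimal sketch of the behaviour this change introduces for the PyTorch OPT model, assuming a transformers install that already contains the patch. The tiny OPTConfig values below are arbitrary and only keep the randomly initialized model small.

import torch
from transformers import OPTConfig, OPTModel

# Tiny, randomly initialized OPT model; the config values are arbitrary and
# only chosen so that no pretrained weights need to be downloaded.
config = OPTConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=64,
    max_position_embeddings=64,
)
model = OPTModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))

# A mask whose length matches the input (no past key values) is accepted.
model(input_ids=input_ids, attention_mask=torch.ones(1, 8))

# A mask of the wrong length now fails fast with the explicit ValueError added
# above, instead of surfacing later as an opaque shape/broadcast error.
try:
    model(input_ids=input_ids, attention_mask=torch.ones(1, 10))
except ValueError as err:
    print(err)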