From 78a2b19fc84ed55c65f4bf20a901edb7ceb73c5f Mon Sep 17 00:00:00 2001 From: "JB (Don)" <1557853+hackyon@users.noreply.github.com> Date: Fri, 30 Jun 2023 21:19:39 +0900 Subject: [PATCH] Show a warning for missing attention masks when pad_token_id is not None (#24510) * Adding warning messages to BERT for missing attention masks These warning messages when there are pad tokens within the input ids and no attention masks are given. The warning message should only show up once. * Adding warning messages to BERT for missing attention masks These warning messages are shown when the pad_token_id is not None and no attention masks are given. The warning message should only show up once. * Ran fix copies to copy over the changes to some of the other models * Add logger.warning_once.cache_clear() to the test * Shows warning when there are no attention masks and input_ids start/end with pad tokens * Using warning_once() instead and fix indexing in input_ids check --------- Co-authored-by: JB Lau --- src/transformers/modeling_utils.py | 30 ++++++++ .../models/altclip/modeling_altclip.py | 1 + src/transformers/models/bert/modeling_bert.py | 1 + .../bridgetower/modeling_bridgetower.py | 1 + .../models/camembert/modeling_camembert.py | 1 + src/transformers/models/clap/modeling_clap.py | 1 + .../models/data2vec/modeling_data2vec_text.py | 1 + .../models/roberta/modeling_roberta.py | 1 + .../xlm_roberta/modeling_xlm_roberta.py | 1 + .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 1 + tests/models/bert/test_modeling_bert.py | 26 ++++++- tests/test_modeling_utils.py | 76 +++++++++++++++++++ 12 files changed, 140 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4fb8043f4410..4c761bc3119a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3477,6 +3477,36 @@ def reverse_bettertransformer(self): return BetterTransformer.reverse(self) + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See " + "https://huggingface.co/docs/transformers/troubleshooting" + "#incorrect-output-when-padding-tokens-arent-masked." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning_once(warn_string) + PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) if PreTrainedModel.push_to_hub.__doc__ is not None: diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index fe2754cac808..90188c044e4e 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -1305,6 +1305,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 17667e8443dd..452741c6d8fb 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -967,6 +967,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 1fb3cc131bc8..42b31c964f0c 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -1118,6 +1118,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index ed3afab11aa4..f0b40f56de5c 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -842,6 +842,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 0f3986ada0ce..c5533f3dae00 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1854,6 +1854,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 4c07acd11072..ed79b021fbf1 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -791,6 +791,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index cf71ceba7c45..0b19804dccf2 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -789,6 +789,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 881f60875dbb..14e2e22086a9 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -791,6 +791,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 4299880e0c4f..7b6c15033cfc 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -757,6 +757,7 @@ def forward( raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: diff --git a/tests/models/bert/test_modeling_bert.py b/tests/models/bert/test_modeling_bert.py index efd540902084..db021740714f 100644 --- a/tests/models/bert/test_modeling_bert.py +++ b/tests/models/bert/test_modeling_bert.py @@ -18,7 +18,7 @@ from transformers import BertConfig, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import CaptureLogger, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -40,6 +40,7 @@ BertForTokenClassification, BertLMHeadModel, BertModel, + logging, ) from transformers.models.bert.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST @@ -567,6 +568,29 @@ def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + def test_for_warning_if_padding_and_no_attention_mask(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.model_tester.prepare_config_and_inputs() + + # Set pad tokens in the input_ids + input_ids[0, 0] = config.pad_token_id + + # Check for warnings if the attention_mask is missing. + logger = logging.get_logger("transformers.modeling_utils") + with CaptureLogger(logger) as cl: + model = BertModel(config=config) + model.to(torch_device) + model.eval() + model(input_ids, attention_mask=None, token_type_ids=token_type_ids) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + @slow def test_model_from_pretrained(self): for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 17ddf1963a28..5019d0ccb308 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -938,6 +938,82 @@ def test_unexpected_keys_warnings(self): self.assertIn("were not used when initializing ModelWithHead: ['added_key']", cl.out) self.assertEqual(loading_info["unexpected_keys"], ["added_key"]) + def test_warn_if_padding_and_no_attention_mask(self): + logger = logging.get_logger("transformers.modeling_utils") + + with self.subTest("Ensure no warnings when pad_token_id is None."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config_no_pad_token = PretrainedConfig() + config_no_pad_token.pad_token_id = None + model = ModelWithHead(config_no_pad_token) + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out) + + with self.subTest("Ensure no warnings when there is an attention_mask."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + model = ModelWithHead(config) + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]]) + attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out) + + with self.subTest("Ensure no warnings when there are no pad_token_ids in the input_ids."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + model = ModelWithHead(config) + input_ids = torch.tensor([[1, 345, 232, 328, 740, 140, 1695, 69, 6078, 2341, 25]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertNotIn("We strongly recommend passing in an `attention_mask`", cl.out) + + with self.subTest("Ensure a warning is shown when the input_ids start with a pad_token_id."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + model = ModelWithHead(config) + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + + with self.subTest("Ensure a warning is shown when the input_ids end with a pad_token_id."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + model = ModelWithHead(config) + input_ids = torch.tensor([[432, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertIn("We strongly recommend passing in an `attention_mask`", cl.out) + + with self.subTest("Ensure that the warning is shown at most once."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + model = ModelWithHead(config) + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertEqual(cl.out.count("We strongly recommend passing in an `attention_mask`"), 1) + + with self.subTest("Ensure a different warning is shown when the pad_token_id is equal to the bos_token_id."): + logger.warning_once.cache_clear() + with CaptureLogger(logger) as cl: + config = PretrainedConfig() + config.pad_token_id = 0 + config.bos_token_id = config.pad_token_id + model = ModelWithHead(config) + input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 0, 0]]) + model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None) + self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out) + @require_torch_gpu @slow def test_pretrained_low_mem_new_config(self):