From 5c521c90c8deb04848271191bf9c1bfafb48d2c8 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 16 Aug 2022 13:26:18 +0000 Subject: [PATCH 1/5] Add TF generate kwarg validation --- src/transformers/generation_tf_utils.py | 30 ++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py index 64cecebbe84838..0664857da2a051 100644 --- a/src/transformers/generation_tf_utils.py +++ b/src/transformers/generation_tf_utils.py @@ -579,6 +579,7 @@ def generate( do_sample = do_sample if do_sample is not None else self.config.do_sample if do_sample is False or num_beams == 1: + seed = model_kwargs.pop("seed", None) return self._generate( input_ids=input_ids, max_length=max_length, @@ -601,13 +602,14 @@ def generate( attention_mask=attention_mask, decoder_start_token_id=decoder_start_token_id, use_cache=use_cache, - seed=model_kwargs.pop("seed", None), + seed=seed, output_scores=output_scores, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict_in_generate=return_dict_in_generate, forced_bos_token_id=forced_bos_token_id, forced_eos_token_id=forced_eos_token_id, + model_kwargs=model_kwargs, ) # We cannot generate if the model does not have a LM head @@ -1288,6 +1290,29 @@ def adjust_logits_during_generation( else: return logits + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # Excludes arguments that are handled before calling any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + unused_model_args = [] + model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) + # `kwargs` if often used to handle optional forward pass inputs like `attention_mask`. If + # `prepare_inputs_for_generation` doesn't accept `kwargs`, then a stricter check can be made ;) + if "kwargs" in model_args: + model_args |= set(inspect.signature(self.call).parameters) + for key, value in model_kwargs.items(): + if value is not None and key not in model_args: + unused_model_args.append(key) + + if unused_model_args: + raise ValueError( + f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" + " generate arguments will also show up in this list)" + ) + def _generate( self, input_ids=None, @@ -1483,6 +1508,9 @@ def _generate( # generate sequences without allowing bad_words to be generated outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) ```""" + # 0. Validate model kwargs + self._validate_model_kwargs(model_kwargs.copy()) + # 1. 
Set generation parameters if not already defined
         length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

From 41414b919aaeee69ed4e8c3afb01ec868bcb9000 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 16 Aug 2022 13:38:07 +0000
Subject: [PATCH 2/5] derp

---
 src/transformers/generation_tf_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation_tf_utils.py b/src/transformers/generation_tf_utils.py
index 0664857da2a051..d5f92b51e722ec 100644
--- a/src/transformers/generation_tf_utils.py
+++ b/src/transformers/generation_tf_utils.py
@@ -609,7 +609,7 @@ def generate(
                 return_dict_in_generate=return_dict_in_generate,
                 forced_bos_token_id=forced_bos_token_id,
                 forced_eos_token_id=forced_eos_token_id,
-                model_kwargs=model_kwargs,
+                **model_kwargs,
             )

         # We cannot generate if the model does not have a LM head

From 82d085ec556fa6517a2614c69846a2229b40cf53 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 16 Aug 2022 14:08:58 +0000
Subject: [PATCH 3/5] create test_generation_tf_utils.py; Add kwarg check tests

---
 tests/generation/test_generation_tf_utils.py | 184 +++++++++++++++++++
 tests/generation/test_generation_utils.py | 4 +-
 tests/test_modeling_tf_common.py | 136 --------------
 3 files changed, 186 insertions(+), 138 deletions(-)
 create mode 100644 tests/generation/test_generation_tf_utils.py

diff --git a/tests/generation/test_generation_tf_utils.py b/tests/generation/test_generation_tf_utils.py
new file mode 100644
index 00000000000000..26a51f25eb5092
--- /dev/null
+++ b/tests/generation/test_generation_tf_utils.py
@@ -0,0 +1,184 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+from transformers import is_tf_available
+from transformers.testing_utils import require_tf, slow
+
+
+if is_tf_available():
+    import tensorflow as tf
+
+    from transformers import tf_top_k_top_p_filtering, TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM, AutoTokenizer
+
+
+@require_tf
+class UtilsFunctionsTest(unittest.TestCase):
+
+    # tests whether the top_k_top_p_filtering function behaves as expected
+    def test_top_k_top_p_filtering(self):
+        logits = tf.convert_to_tensor(
+            [
+                [
+                    8.2220991,  # 3rd highest value; idx. 0
+                    -0.5620044,
+                    5.23229752,
+                    4.0386393,
+                    -6.8798378,
+                    -0.54785802,
+                    -3.2012153,
+                    2.92777176,
+                    1.88171953,
+                    7.35341276,  # 5th highest value; idx. 9
+                    8.43207833,  # 2nd highest value; idx. 10
+                    -9.85711836,
+                    -5.96209236,
+                    -1.13039161,
+                    -7.1115294,
+                    -0.8369633,
+                    -5.3186408,
+                    7.06427407,
+                    0.81369344,
+                    -0.82023817,
+                    -5.9179796,
+                    0.58813443,
+                    -6.99778438,
+                    4.71551189,
+                    -0.18771637,
+                    7.44020759,  # 4th highest value; idx. 25
+                    9.38450987,  # 1st highest value; idx. 
26 + 2.12662941, + -9.32562038, + 2.35652522, + ], # cummulative prob of 5 highest values <= 0.6 + [ + 0.58425518, + 4.53139238, + -5.57510464, + -6.28030699, + -7.19529503, + -4.02122551, + 1.39337037, + -6.06707057, + 1.59480517, + -9.643119, + 0.03907799, + 0.67231762, + -8.88206726, + 6.27115922, # 4th highest value; idx. 13 + 2.28520723, + 4.82767506, + 4.30421368, + 8.8275313, # 2nd highest value; idx. 17 + 5.44029958, # 5th highest value; idx. 18 + -4.4735794, + 7.38579536, # 3rd highest value; idx. 20 + -2.91051663, + 2.61946077, + -2.5674762, + -9.48959302, + -4.02922645, + -1.35416918, + 9.67702323, # 1st highest value; idx. 27 + -5.89478553, + 1.85370467, + ], # cummulative prob of 5 highest values <= 0.6 + ], + dtype=tf.float32, + ) + + non_inf_expected_idx = tf.convert_to_tensor( + [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], + dtype=tf.int32, + ) # expected non filtered idx as noted above + + non_inf_expected_output = tf.convert_to_tensor( + [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023], + dtype=tf.float32, + ) # expected non filtered values as noted above + + output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) + + non_inf_output = output[output != -float("inf")] + non_inf_idx = tf.cast( + tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))), + dtype=tf.int32, + ) + + tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) + tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) + + +@require_tf +class TFGenerationIntegrationTests(unittest.TestCase): + + @slow + def test_generate_tf_function_export(self): + test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") + max_length = 2 + + class DummyModel(tf.Module): + def __init__(self, model): + super(DummyModel, self).__init__() + self.model = model + + @tf.function( + input_signature=( + tf.TensorSpec((None, max_length), tf.int32, name="input_ids"), + tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"), + ), + jit_compile=True, + ) + def serving(self, input_ids, attention_mask): + outputs = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=max_length, + return_dict_in_generate=True, + ) + return {"sequences": outputs["sequences"]} + + dummy_input_ids = [[2, 0], [102, 103]] + dummy_attention_masks = [[1, 0], [1, 1]] + dummy_model = DummyModel(model=test_model) + with tempfile.TemporaryDirectory() as tmp_dir: + tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving}) + serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"] + for batch_size in range(1, len(dummy_input_ids) + 1): + inputs = { + "input_ids": tf.constant(dummy_input_ids[:batch_size]), + "attention_mask": tf.constant(dummy_attention_masks[:batch_size]), + } + tf_func_outputs = serving_func(**inputs)["sequences"] + tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length) + tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) + + def test_validate_generation_inputs(self): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") + + encoder_input_str = "Hello world" + input_ids = tokenizer(encoder_input_str, return_tensors="tf").input_ids + + # typos are quickly detected (the correct argument is 
`do_sample`) + with self.assertRaisesRegex(ValueError, "do_samples"): + model.generate(input_ids, do_samples=True) + + # arbitrary arguments that will not be used anywhere are also not accepted + with self.assertRaisesRegex(ValueError, "foo"): + fake_model_kwargs = {"foo": "bar"} + model.generate(input_ids, **fake_model_kwargs) diff --git a/tests/generation/test_generation_utils.py b/tests/generation/test_generation_utils.py index 62a3f588cf471b..e8cb57ccf3c97d 100644 --- a/tests/generation/test_generation_utils.py +++ b/tests/generation/test_generation_utils.py @@ -2704,8 +2704,8 @@ def test_constrained_beam_search_mixin_type_checks(self): model.generate(input_ids, force_words_ids=[[[-1]]]) def test_validate_generation_inputs(self): - tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") - model = AutoModelForSeq2SeqLM.from_pretrained("patrickvonplaten/t5-tiny-random") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5") encoder_input_str = "Hello world" input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index a0413afdfa0ec9..f3608f4b225d86 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -75,11 +75,9 @@ TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, BertConfig, TFAutoModel, - TFAutoModelForCausalLM, TFAutoModelForSequenceClassification, TFBertModel, TFSharedEmbeddings, - tf_top_k_top_p_filtering, ) from transformers.generation_tf_utils import ( TFBeamSampleDecoderOnlyOutput, @@ -1824,100 +1822,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None): @require_tf class UtilsFunctionsTest(unittest.TestCase): - - # tests whether the top_k_top_p_filtering function behaves as expected - def test_top_k_top_p_filtering(self): - logits = tf.convert_to_tensor( - [ - [ - 8.2220991, # 3rd highest value; idx. 0 - -0.5620044, - 5.23229752, - 4.0386393, - -6.8798378, - -0.54785802, - -3.2012153, - 2.92777176, - 1.88171953, - 7.35341276, # 5th highest value; idx. 9 - 8.43207833, # 2nd highest value; idx. 10 - -9.85711836, - -5.96209236, - -1.13039161, - -7.1115294, - -0.8369633, - -5.3186408, - 7.06427407, - 0.81369344, - -0.82023817, - -5.9179796, - 0.58813443, - -6.99778438, - 4.71551189, - -0.18771637, - 7.44020759, # 4th highest value; idx. 25 - 9.38450987, # 1st highest value; idx. 26 - 2.12662941, - -9.32562038, - 2.35652522, - ], # cummulative prob of 5 highest values <= 0.6 - [ - 0.58425518, - 4.53139238, - -5.57510464, - -6.28030699, - -7.19529503, - -4.02122551, - 1.39337037, - -6.06707057, - 1.59480517, - -9.643119, - 0.03907799, - 0.67231762, - -8.88206726, - 6.27115922, # 4th highest value; idx. 13 - 2.28520723, - 4.82767506, - 4.30421368, - 8.8275313, # 2nd highest value; idx. 17 - 5.44029958, # 5th highest value; idx. 18 - -4.4735794, - 7.38579536, # 3rd highest value; idx. 20 - -2.91051663, - 2.61946077, - -2.5674762, - -9.48959302, - -4.02922645, - -1.35416918, - 9.67702323, # 1st highest value; idx. 
27 - -5.89478553, - 1.85370467, - ], # cummulative prob of 5 highest values <= 0.6 - ], - dtype=tf.float32, - ) - - non_inf_expected_idx = tf.convert_to_tensor( - [[0, 0], [0, 9], [0, 10], [0, 25], [0, 26], [1, 13], [1, 17], [1, 18], [1, 20], [1, 27]], - dtype=tf.int32, - ) # expected non filtered idx as noted above - - non_inf_expected_output = tf.convert_to_tensor( - [8.222099, 7.3534126, 8.432078, 7.4402075, 9.38451, 6.271159, 8.827531, 5.4402995, 7.3857956, 9.677023], - dtype=tf.float32, - ) # expected non filtered values as noted above - - output = tf_top_k_top_p_filtering(logits, top_k=10, top_p=0.6, min_tokens_to_keep=4) - - non_inf_output = output[output != -float("inf")] - non_inf_idx = tf.cast( - tf.where(tf.not_equal(output, tf.constant(-float("inf"), dtype=tf.float32))), - dtype=tf.int32, - ) - - tf.debugging.assert_near(non_inf_output, non_inf_expected_output, rtol=1e-12) - tf.debugging.assert_equal(non_inf_idx, non_inf_expected_idx) - def test_cached_files_are_used_when_internet_is_down(self): # A mock response for an HTTP head request to emulate server down response_mock = mock.Mock() @@ -2179,46 +2083,6 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(model.weights, new_model.weights): self.assertTrue(np.allclose(p1.numpy(), p2.numpy())) - def test_generate_tf_function_export(self): - test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") - max_length = 2 - - class DummyModel(tf.Module): - def __init__(self, model): - super(DummyModel, self).__init__() - self.model = model - - @tf.function( - input_signature=( - tf.TensorSpec((None, max_length), tf.int32, name="input_ids"), - tf.TensorSpec((None, max_length), tf.int32, name="attention_mask"), - ), - jit_compile=True, - ) - def serving(self, input_ids, attention_mask): - outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_length, - return_dict_in_generate=True, - ) - return {"sequences": outputs["sequences"]} - - dummy_input_ids = [[2, 0], [102, 103]] - dummy_attention_masks = [[1, 0], [1, 1]] - dummy_model = DummyModel(model=test_model) - with tempfile.TemporaryDirectory() as tmp_dir: - tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving}) - serving_func = tf.saved_model.load(tmp_dir).signatures["serving_default"] - for batch_size in range(1, len(dummy_input_ids) + 1): - inputs = { - "input_ids": tf.constant(dummy_input_ids[:batch_size]), - "attention_mask": tf.constant(dummy_attention_masks[:batch_size]), - } - tf_func_outputs = serving_func(**inputs)["sequences"] - tf_model_outputs = test_model.generate(**inputs, max_new_tokens=max_length) - tf.debugging.assert_equal(tf_func_outputs, tf_model_outputs) - @require_tf @is_staging_test From c8a24c5d539628457b18d604cf589d3d6e49abb6 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 16 Aug 2022 14:13:14 +0000 Subject: [PATCH 4/5] make fixup --- tests/generation/test_generation_tf_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/generation/test_generation_tf_utils.py b/tests/generation/test_generation_tf_utils.py index 26a51f25eb5092..a3a27249b57fc0 100644 --- a/tests/generation/test_generation_tf_utils.py +++ b/tests/generation/test_generation_tf_utils.py @@ -125,7 +125,6 @@ def test_top_k_top_p_filtering(self): @require_tf class TFGenerationIntegrationTests(unittest.TestCase): - @slow def test_generate_tf_function_export(self): test_model = 
TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2") From 3eb50722ba7aff1296b9fd35172f7de0abaa0beb Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 2 Sep 2022 13:12:29 +0000 Subject: [PATCH 5/5] make fixup --- tests/generation/test_generation_tf_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_generation_tf_utils.py b/tests/generation/test_generation_tf_utils.py index a3a27249b57fc0..d0d284182b53c3 100644 --- a/tests/generation/test_generation_tf_utils.py +++ b/tests/generation/test_generation_tf_utils.py @@ -23,7 +23,7 @@ if is_tf_available(): import tensorflow as tf - from transformers import tf_top_k_top_p_filtering, TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM, AutoTokenizer + from transformers import AutoTokenizer, TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM, tf_top_k_top_p_filtering @require_tf
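
Taken together, the series makes a misspelled or unused keyword argument to TF `generate()` fail loudly instead of being silently dropped. The short usage sketch below is an illustration only, not part of the patch: it mirrors the new `test_validate_generation_inputs` test and assumes the `hf-internal-testing/tiny-random-t5` checkpoint used by that test is available.

```python
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
input_ids = tokenizer("Hello world", return_tensors="tf").input_ids

# `do_samples` is a typo for `do_sample`; with the `_validate_model_kwargs` check
# added in PATCH 1/5 it now raises a ValueError listing the unused arguments.
try:
    model.generate(input_ids, do_samples=True)
except ValueError as err:
    print(err)  # The following `model_kwargs` are not used by the model: ['do_samples'] ...
```

Because `prepare_inputs_for_generation` may accept a catch-all `**kwargs`, the check also inspects the signature of `self.call`, so legitimate optional forward-pass inputs such as `attention_mask` keep passing through untouched.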