From 239665efa797e90ac4b6c848e40ad76a65248fb7 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:16:44 +0200 Subject: [PATCH] Fix missing test in `torch_job` (#33593) fix missing tests Co-authored-by: ydshieh --- tests/generation/test_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 08b40e71cf1f3c..e5a83beac868c4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -89,7 +89,6 @@ from transformers.generation.utils import _speculative_sampling -@pytest.mark.generate class GenerationTesterMixin: model_tester = None all_generative_model_classes = () @@ -2035,6 +2034,7 @@ def test_generate_compile_fullgraph(self): output_compiled = compiled_generate(model_inputs, generation_config=generation_config) self.assertListEqual(output_dynamic.tolist(), output_compiled.tolist()) + @pytest.mark.generate def test_generate_methods_with_num_logits_to_keep(self): for model_class in self.all_generative_model_classes: if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()): @@ -2063,6 +2063,7 @@ def test_generate_methods_with_num_logits_to_keep(self): ) self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist()) + @pytest.mark.generate @is_flaky() # assisted generation tests are flaky (minor fp ops differences) def test_assisted_decoding_with_num_logits_to_keep(self): for model_class in self.all_generative_model_classes: