From 0e23e60a5ad1be33b5a0ada9e42e1ac273c5e08e Mon Sep 17 00:00:00 2001 From: Hiroshi Matsuda <40782025+hiroshi-matsuda-rit@users.noreply.github.com> Date: Mon, 24 Jun 2024 22:05:16 +0900 Subject: [PATCH] Fix bug about add_special_tokens and so on (#31496) * fix bug about add_special_tokens and so on * improve add_special_tokens and padding behavior * add a test case for add_special_tokens and padding --- src/transformers/pipelines/text_generation.py | 5 ++--- tests/pipelines/test_pipelines_text_generation.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index ca8e5da6ea5004..c2dce89dd701be 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -137,11 +137,10 @@ def _sanitize_parameters( add_special_tokens = False if "add_special_tokens" in generate_kwargs: - preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"] - add_special_tokens = generate_kwargs["add_special_tokens"] + add_special_tokens = preprocess_params["add_special_tokens"] = generate_kwargs.pop("add_special_tokens") if "padding" in generate_kwargs: - preprocess_params["padding"] = generate_kwargs["padding"] + preprocess_params["padding"] = generate_kwargs.pop("padding") if truncation is not None: preprocess_params["truncation"] = truncation diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 542f393b20257d..4c91fd46cd978d 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -107,6 +107,20 @@ def test_small_model_pt(self): ) assert output_str != output_str_with_truncation # results must be different because one had truncation + ## -- test kwargs for preprocess_params + outputs = text_generator("This is a test", do_sample=False, add_special_tokens=False, 
padding=False) + self.assertEqual( + outputs, + [ + { + "generated_text": ( + "This is a test ☃ ☃ segmental segmental segmental 议议eski eski flutter flutter Lacy oscope." + " oscope. FiliFili@@" + ) + } + ], + ) + text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id text_generator.tokenizer.pad_token = ""