From 1e09237aea7f8675f03085087fc38312dea9c4d0 Mon Sep 17 00:00:00 2001
From: Hiroshi Matsuda <40782025+hiroshi-matsuda-rit@users.noreply.github.com>
Date: Wed, 19 Jun 2024 23:11:05 +0900
Subject: [PATCH 1/3] fix bug about add_special_tokens and so on

---
 src/transformers/pipelines/text_generation.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index ca8e5da6ea5004..37729b5f2dda7d 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -167,6 +167,8 @@ def _sanitize_parameters(
             preprocess_params["handle_long_generation"] = handle_long_generation
 
         preprocess_params.update(generate_kwargs)
+        for preprocess_only_key in ["add_special_tokens", "padding"]:
+            generate_kwargs.pop(preprocess_only_key, False)
         forward_params = generate_kwargs
 
         postprocess_params = {}

From 66be0f89248f435d34bf3f971857315ead066c0f Mon Sep 17 00:00:00 2001
From: Hiroshi Matsuda <40782025+hiroshi-matsuda-rit@users.noreply.github.com>
Date: Fri, 21 Jun 2024 06:45:19 +0900
Subject: [PATCH 2/3] improve add_special_tokens and padding behavior

---
 src/transformers/pipelines/text_generation.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py
index 37729b5f2dda7d..c2dce89dd701be 100644
--- a/src/transformers/pipelines/text_generation.py
+++ b/src/transformers/pipelines/text_generation.py
@@ -137,11 +137,10 @@ def _sanitize_parameters(
 
         add_special_tokens = False
         if "add_special_tokens" in generate_kwargs:
-            preprocess_params["add_special_tokens"] = generate_kwargs["add_special_tokens"]
-            add_special_tokens = generate_kwargs["add_special_tokens"]
+            add_special_tokens = preprocess_params["add_special_tokens"] = generate_kwargs.pop("add_special_tokens")
 
         if "padding" in generate_kwargs:
-            preprocess_params["padding"] = generate_kwargs["padding"]
+            preprocess_params["padding"] = generate_kwargs.pop("padding")
 
         if truncation is not None:
             preprocess_params["truncation"] = truncation
@@ -167,8 +166,6 @@ def _sanitize_parameters(
             preprocess_params["handle_long_generation"] = handle_long_generation
 
         preprocess_params.update(generate_kwargs)
-        for preprocess_only_key in ["add_special_tokens", "padding"]:
-            generate_kwargs.pop(preprocess_only_key, False)
         forward_params = generate_kwargs
 
         postprocess_params = {}

From 3e99660227f0130b79527ebe5a0978eab2683f07 Mon Sep 17 00:00:00 2001
From: Hiroshi Matsuda <40782025+hiroshi-matsuda-rit@users.noreply.github.com>
Date: Sat, 22 Jun 2024 22:32:44 +0900
Subject: [PATCH 3/3] add a test case for add_special_tokens and padding

---
 tests/pipelines/test_pipelines_text_generation.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py
index 542f393b20257d..4c91fd46cd978d 100644
--- a/tests/pipelines/test_pipelines_text_generation.py
+++ b/tests/pipelines/test_pipelines_text_generation.py
@@ -107,6 +107,20 @@ def test_small_model_pt(self):
         )
         assert output_str != output_str_with_truncation  # results must be different because one had truncation
 
+        ## -- test kwargs for preprocess_params
+        outputs = text_generator("This is a test", do_sample=False, add_special_tokens=False, padding=False)
+        self.assertEqual(
+            outputs,
+            [
+                {
+                    "generated_text": (
+                        "This is a test ☃ ☃ segmental segmental segmental 议议eski eski flutter flutter Lacy oscope."
+                        " oscope. FiliFili@@"
+                    )
+                }
+            ],
+        )
+
        # -- what is the point of this test? padding is hardcoded False in the pipeline anyway
        text_generator.tokenizer.pad_token_id = text_generator.model.config.eos_token_id
        text_generator.tokenizer.pad_token = "<pad>"