From 5351f0f20c9a3c6cc1a77b741d3b7d5bc2130273 Mon Sep 17 00:00:00 2001
From: GM07
Date: Fri, 19 Jul 2024 12:38:06 -0400
Subject: [PATCH 1/2] Adding arg for custom beam search scorer in generate()

---
 .../generation/configuration_utils.py |  17 +++
 src/transformers/generation/utils.py  |  61 +++++++---
 tests/generation/test_utils.py        | 108 ++++++++++++++++++
 3 files changed, 167 insertions(+), 19 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index c7e626f1a7c2..77899ef6a47e 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -133,6 +133,9 @@ class GenerationConfig(PushToHubMixin):
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should use the past last key/values attentions (if applicable to the model) to
             speed up decoding.
+        beam_search_scorer_class (`class`, *optional*, defaults to `None`):
+            Which class to use as a beam search scorer. If `None`, the default `BeamSearchScorer` class is used.
+            The type must inherit from `BeamSearchScorer`.
 
         > Parameters for manipulation of the model output logits
 
@@ -353,6 +356,7 @@ def __init__(self, **kwargs):
         self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
         self.penalty_alpha = kwargs.pop("penalty_alpha", None)
         self.use_cache = kwargs.pop("use_cache", True)
+        self.beam_search_scorer_class = kwargs.pop("beam_search_scorer_class", None)
 
         # Parameters for manipulation of the model output logits
         self.temperature = kwargs.pop("temperature", 1.0)
@@ -640,6 +644,13 @@ def validate(self, is_init=False):
                     single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
                     UserWarning,
                 )
+            if self.beam_search_scorer_class is not None:
+                warnings.warn(
+                    single_beam_wrong_parameter_msg.format(
+                        flag_name="beam_search_scorer_class", flag_value=self.beam_search_scorer_class
+                    ),
+                    UserWarning,
+                )
 
         # 3. detect incorrect paramaterization specific to advanced beam modes
         else:
@@ -660,6 +671,12 @@ def validate(self, is_init=False):
                             flag_name="num_beam_groups", flag_value=self.num_beam_groups
                         )
                     )
+                if self.beam_search_scorer_class is not None:
+                    raise ValueError(
+                        constrained_wrong_parameter_msg.format(
+                            flag_name="beam_search_scorer_class", flag_value=self.beam_search_scorer_class
+                        )
+                    )
 
             # group beam search
             if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
                 group_error_prefix = (
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 9d3a92d26881..05bad22b8b89 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2006,15 +2006,27 @@ def generate(
             )
 
             # 12. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                max_length=generation_config.max_length,
-            )
+            if generation_config.beam_search_scorer_class is None:
+                beam_scorer = BeamSearchScorer(
+                    batch_size=batch_size,
+                    num_beams=generation_config.num_beams,
+                    device=inputs_tensor.device,
+                    length_penalty=generation_config.length_penalty,
+                    do_early_stopping=generation_config.early_stopping,
+                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                    max_length=generation_config.max_length,
+                )
+            else:
+                beam_scorer = generation_config.beam_search_scorer_class(
+                    batch_size=batch_size,
+                    num_beams=generation_config.num_beams,
+                    device=inputs_tensor.device,
+                    length_penalty=generation_config.length_penalty,
+                    do_early_stopping=generation_config.early_stopping,
+                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                    num_beam_groups=generation_config.num_beam_groups,
+                    max_length=generation_config.max_length,
+                )
 
             # 13. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2038,16 +2050,27 @@ def generate(
 
         elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
             # 11. prepare beam search scorer
-            beam_scorer = BeamSearchScorer(
-                batch_size=batch_size,
-                num_beams=generation_config.num_beams,
-                device=inputs_tensor.device,
-                length_penalty=generation_config.length_penalty,
-                do_early_stopping=generation_config.early_stopping,
-                num_beam_hyps_to_keep=generation_config.num_return_sequences,
-                num_beam_groups=generation_config.num_beam_groups,
-                max_length=generation_config.max_length,
-            )
+            if generation_config.beam_search_scorer_class is None:
+                beam_scorer = BeamSearchScorer(
+                    batch_size=batch_size,
+                    num_beams=generation_config.num_beams,
+                    device=inputs_tensor.device,
+                    length_penalty=generation_config.length_penalty,
+                    do_early_stopping=generation_config.early_stopping,
+                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                    max_length=generation_config.max_length,
+                )
+            else:
+                beam_scorer = generation_config.beam_search_scorer_class(
+                    batch_size=batch_size,
+                    num_beams=generation_config.num_beams,
+                    device=inputs_tensor.device,
+                    length_penalty=generation_config.length_penalty,
+                    do_early_stopping=generation_config.early_stopping,
+                    num_beam_hyps_to_keep=generation_config.num_return_sequences,
+                    num_beam_groups=generation_config.num_beam_groups,
+                    max_length=generation_config.max_length,
+                )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
                 input_ids=input_ids,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index b21183618897..074ffeffee64 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -65,6 +65,7 @@
     BeamSampleEncoderDecoderOutput,
     BeamSearchDecoderOnlyOutput,
     BeamSearchEncoderDecoderOutput,
+    BeamSearchScorer,
     DisjunctiveConstraint,
     GenerateBeamDecoderOnlyOutput,
     GenerateBeamEncoderDecoderOutput,
@@ -580,6 +581,59 @@ def test_beam_search_generate(self):
             else:
                 self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
 
+    def test_custom_beam_search_scorer_generate(self):
+        class CustomBeamSearchScorer(BeamSearchScorer):
+            finalize_called = False
+            process_called = False
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+            def process(self, *args, **kwargs):
+                results = super().process(*args, **kwargs)
+                CustomBeamSearchScorer.process_called = True
+                return results
+
+            def finalize(self, *args, **kwargs):
+                results = super().finalize(*args, **kwargs)
+                CustomBeamSearchScorer.finalize_called = True
+                return results
+
+        for model_class in self.all_generative_model_classes:
+            CustomBeamSearchScorer.process_called = False
+            CustomBeamSearchScorer.finalize_called = False
+
+            config, input_ids, attention_mask = self._get_input_ids_and_config()
+
+            model = model_class(config).to(torch_device).eval()
+
+            logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
+                input_ids.shape[-1],
+                config.forced_bos_token_id,
+                config.forced_eos_token_id,
+            )
+            beam_kwargs = self._get_beam_kwargs()
+
+            model_kwargs = {"attention_mask": attention_mask}
+            output_generate = model.generate(
+                input_ids,
+                do_sample=False,
+                max_new_tokens=self.max_new_tokens,
+                beam_search_scorer_class=CustomBeamSearchScorer,
+                output_scores=False,
+                output_logits=False,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict_in_generate=False,
+                **beam_kwargs,
+                **logits_process_kwargs,
+                **model_kwargs,
+            )
+
+            self.assertIsNotNone(output_generate)
+            self.assertTrue(CustomBeamSearchScorer.process_called)
+            self.assertTrue(CustomBeamSearchScorer.finalize_called)
+
     def test_beam_search_generate_dict_output(self):
         for model_class in self.all_generative_model_classes:
             config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -2277,6 +2331,60 @@ def test_beam_search_example_integration(self):
 
         self.assertListEqual(outputs, ["Wie alt bist du?"])
 
+    @slow
+    def test_beam_search_custom_scorer(self):
+        tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
+        model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
+
+        encoder_input_str = "translate English to German: How old are you?"
+        encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
+
+        class CustomBeamSearchScorer(BeamSearchScorer):
+            finalize_called = False
+            process_called = False
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+
+            def process(self, *args, **kwargs):
+                results = super().process(*args, **kwargs)
+                CustomBeamSearchScorer.process_called = True
+                return results
+
+            def finalize(self, *args, **kwargs):
+                results = super().finalize(*args, **kwargs)
+                CustomBeamSearchScorer.finalize_called = True
+                return results
+
+        # lets run beam search using 3 beams
+        num_beams = 3
+        # define decoder start token ids
+        input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
+        input_ids = input_ids * model.config.decoder_start_token_id
+
+        # add encoder_outputs to model keyword arguments
+        model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}
+
+        outputs = model.generate(
+            input_ids,
+            num_beams=num_beams,
+            min_length=5,
+            eos_token_id=model.config.eos_token_id,
+            beam_search_scorer_class=CustomBeamSearchScorer,
+            **model_kwargs,
+        )
+        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+        self.assertListEqual(outputs, ["Wie alt bist du?"])
+        self.assertTrue(
+            CustomBeamSearchScorer.process_called,
+            "The `process` function of the custom beam search scorer was not called",
+        )
+        self.assertTrue(
+            CustomBeamSearchScorer.finalize_called,
+            "The `finalize` function of the custom beam search scorer was not called",
+        )
+
     @slow
     def test_constrained_beam_search(self):
         # PT-only test: TF doesn't have constrained beam search

From 5437234b2be625607382bda05198d413cda0c260 Mon Sep 17 00:00:00 2001
From: GM07
Date: Mon, 22 Jul 2024 14:27:09 -0400
Subject: [PATCH 2/2] Adding support to pass arguments to custom beam search scorer

---
 .../generation/configuration_utils.py |  9 +++++++++
 src/transformers/generation/utils.py  | 13 ++++++++++++-
 tests/generation/test_utils.py        | 16 ++++++++++++++--
 3 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 77899ef6a47e..c3d4c9db87fa 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -136,6 +136,9 @@ class GenerationConfig(PushToHubMixin):
         beam_search_scorer_class (`class`, *optional*, defaults to `None`):
             Which class to use as a beam search scorer. If `None`, the default `BeamSearchScorer` class is used.
             The type must inherit from `BeamSearchScorer`.
+        beam_search_scorer_args (`dict`, *optional*, defaults to `None`):
+            Arguments that will be passed when creating the beam search scorer. When this argument is specified,
+            `beam_search_scorer_class` must not be `None`.
 
         > Parameters for manipulation of the model output logits
 
@@ -357,6 +360,7 @@ def __init__(self, **kwargs):
         self.penalty_alpha = kwargs.pop("penalty_alpha", None)
         self.use_cache = kwargs.pop("use_cache", True)
         self.beam_search_scorer_class = kwargs.pop("beam_search_scorer_class", None)
+        self.beam_search_scorer_args = kwargs.pop("beam_search_scorer_args", None)
 
         # Parameters for manipulation of the model output logits
         self.temperature = kwargs.pop("temperature", 1.0)
@@ -652,6 +656,11 @@ def validate(self, is_init=False):
                     UserWarning,
                 )
 
+            if self.beam_search_scorer_class is None and self.beam_search_scorer_args is not None:
+                raise ValueError(
+                    "The initialization arguments for the beam search scorer class were provided, but the class was not",
+                )
+
         # 3. detect incorrect paramaterization specific to advanced beam modes
         else:
             # constrained beam search
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 05bad22b8b89..9ad6ce13679b 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2004,7 +2004,6 @@ def generate(
                 if generation_config.do_sample
                 else None
             )
-
             # 12. prepare beam search scorer
             if generation_config.beam_search_scorer_class is None:
                 beam_scorer = BeamSearchScorer(
@@ -2017,6 +2016,11 @@ def generate(
                     max_length=generation_config.max_length,
                 )
             else:
+                args = (
+                    generation_config.beam_search_scorer_args
+                    if generation_config.beam_search_scorer_args is not None
+                    else {}
+                )
                 beam_scorer = generation_config.beam_search_scorer_class(
                     batch_size=batch_size,
                     num_beams=generation_config.num_beams,
@@ -2026,6 +2030,7 @@ def generate(
                     num_beam_hyps_to_keep=generation_config.num_return_sequences,
                     num_beam_groups=generation_config.num_beam_groups,
                     max_length=generation_config.max_length,
+                    **args,
                 )
 
             # 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -2061,6 +2066,11 @@ def generate(
                     max_length=generation_config.max_length,
                 )
             else:
+                args = (
+                    generation_config.beam_search_scorer_args
+                    if generation_config.beam_search_scorer_args is not None
+                    else {}
+                )
                 beam_scorer = generation_config.beam_search_scorer_class(
                     batch_size=batch_size,
                     num_beams=generation_config.num_beams,
                     device=inputs_tensor.device,
                     length_penalty=generation_config.length_penalty,
                     do_early_stopping=generation_config.early_stopping,
                     num_beam_hyps_to_keep=generation_config.num_return_sequences,
                     num_beam_groups=generation_config.num_beam_groups,
                     max_length=generation_config.max_length,
+                    **args,
                 )
             # 12. interleave input_ids with `num_beams` additional sequences per batch
             input_ids, model_kwargs = self._expand_inputs_for_generation(
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 074ffeffee64..80336d607826 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -586,8 +586,13 @@ class CustomBeamSearchScorer(BeamSearchScorer):
             finalize_called = False
             process_called = False
 
-            def __init__(self, *args, **kwargs):
+            def __init__(inside_self, test_args, *args, **kwargs):
                 super().__init__(*args, **kwargs)
+                inside_self.test_args = test_args
+                self.assertTrue(
+                    inside_self.test_args,
+                    "The argument `test_args` should " "have been passed to the beam search scorer init function",
+                )
 
             def process(self, *args, **kwargs):
                 results = super().process(*args, **kwargs)
@@ -620,6 +625,7 @@ def finalize(self, *args, **kwargs):
                 do_sample=False,
                 max_new_tokens=self.max_new_tokens,
                 beam_search_scorer_class=CustomBeamSearchScorer,
+                beam_search_scorer_args={"test_args": True},
                 output_scores=False,
                 output_logits=False,
                 output_attentions=False,
@@ -2343,8 +2349,13 @@ class CustomBeamSearchScorer(BeamSearchScorer):
             finalize_called = False
             process_called = False
 
-            def __init__(self, *args, **kwargs):
+            def __init__(inside_self, test_args, *args, **kwargs):
                 super().__init__(*args, **kwargs)
+                inside_self.test_args = test_args
+                self.assertTrue(
+                    inside_self.test_args,
+                    ("The argument `test_args` should " "have been passed to the beam search scorer init function"),
+                )
 
             def process(self, *args, **kwargs):
                 results = super().process(*args, **kwargs)
@@ -2371,6 +2382,7 @@ def finalize(self, *args, **kwargs):
             min_length=5,
             eos_token_id=model.config.eos_token_id,
             beam_search_scorer_class=CustomBeamSearchScorer,
+            beam_search_scorer_args={"test_args": True},
            **model_kwargs,
         )
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
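
A minimal usage sketch of the two arguments these patches introduce (`beam_search_scorer_class` from PATCH 1/2 and `beam_search_scorer_args` from PATCH 2/2), assuming both patches are applied. The `MyScorer` class and its `length_bonus` parameter are illustrative only and not part of the patches:

```python
# Usage sketch: pass a custom beam search scorer class and its extra
# constructor arguments to generate(). Assumes both patches are applied;
# `MyScorer` and `length_bonus` are hypothetical names for illustration.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BeamSearchScorer


class MyScorer(BeamSearchScorer):
    """Custom scorer that accepts one extra constructor argument and otherwise defers to the default logic."""

    def __init__(self, length_bonus=0.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.length_bonus = length_bonus  # hypothetical knob, unused by the base class


tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

outputs = model.generate(
    **inputs,
    num_beams=3,
    beam_search_scorer_class=MyScorer,  # introduced in PATCH 1/2
    beam_search_scorer_args={"length_bonus": 0.5},  # introduced in PATCH 2/2
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```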