Custom beam search scorer argument in generate function #32097

Open · wants to merge 2 commits into base: main
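The change is easiest to see from the caller's side. Below is a minimal usage sketch, assuming this branch is installed; `MyScorer` and `my_flag` are illustrative names, not part of the diff, and the checkpoint is the one used in the slow test added below.

```python
# Minimal usage sketch of the new `generate()` arguments proposed in this PR.
# `MyScorer` and `my_flag` are illustrative names, not part of the diff.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BeamSearchScorer


class MyScorer(BeamSearchScorer):
    # The custom class must inherit from BeamSearchScorer; any extra constructor
    # arguments are supplied through `beam_search_scorer_args`.
    def __init__(self, my_flag, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_flag = my_flag


tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")

outputs = model.generate(
    **inputs,
    num_beams=3,
    beam_search_scorer_class=MyScorer,          # new: used instead of BeamSearchScorer
    beam_search_scorer_args={"my_flag": True},  # new: forwarded to MyScorer.__init__
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```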
26 changes: 26 additions & 0 deletions src/transformers/generation/configuration_utils.py
@@ -133,6 +133,12 @@ class GenerationConfig(PushToHubMixin):
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
beam_search_scorer_class (`class`, *optional*, defaults to `None`):
Which class to use as a beam search scorer. If `None`, the default `BeamSearchScorer` class is used.
The class must inherit from `BeamSearchScorer`.
beam_search_scorer_args (`dict`, *optional*, defaults to `None`):
Arguments that will be passed when creating the beam search scorer. When this argument is specified,
`beam_search_scorer_class` must not be `None`.

> Parameters for manipulation of the model output logits

@@ -353,6 +359,8 @@ def __init__(self, **kwargs):
self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
self.use_cache = kwargs.pop("use_cache", True)
self.beam_search_scorer_class = kwargs.pop("beam_search_scorer_class", None)
self.beam_search_scorer_args = kwargs.pop("beam_search_scorer_args", None)

# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)
@@ -640,6 +648,18 @@ def validate(self, is_init=False):
single_beam_wrong_parameter_msg.format(flag_name="constraints", flag_value=self.constraints),
UserWarning,
)
if self.beam_search_scorer_class is not None:
warnings.warn(
single_beam_wrong_parameter_msg.format(
flag_name="beam_search_scorer_class", flag_value=self.beam_search_scorer_class
),
UserWarning,
)

if self.beam_search_scorer_class is None and self.beam_search_scorer_args is not None:
raise ValueError(
"The initialization arguments for the beam search scorer class were provided, but the class was not",
)

# 3. detect incorrect parameterization specific to advanced beam modes
else:
@@ -660,6 +680,12 @@
flag_name="num_beam_groups", flag_value=self.num_beam_groups
)
)
if self.beam_search_scorer_class is not None:
raise ValueError(
constrained_wrong_parameter_msg.format(
flag_name="beam_search_scorer_class", flag_value=self.beam_search_scorer_class
)
)
# group beam search
if self.diversity_penalty != 0.0 or self.num_beam_groups != 1:
group_error_prefix = (
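The validation added above can be exercised directly on a `GenerationConfig`. A small sketch, assuming the argument names from this diff; the `{"alpha": 0.5}` payload is an arbitrary placeholder.

```python
# Sketch of the new validation rules; the dict payload is an arbitrary example.
from transformers import GenerationConfig

# Providing scorer arguments without a scorer class is rejected.
try:
    config = GenerationConfig(beam_search_scorer_args={"alpha": 0.5})
    config.validate()
except ValueError as err:
    print(err)  # "The initialization arguments for the beam search scorer class were provided, but the class was not"

# By contrast, setting `beam_search_scorer_class` follows the existing pattern
# for beam-only flags: a UserWarning in the single-beam case and a ValueError
# for constrained beam search.
```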
74 changes: 54 additions & 20 deletions src/transformers/generation/utils.py
@@ -2004,17 +2004,34 @@ def generate(
if generation_config.do_sample
else None
)

# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
if generation_config.beam_search_scorer_class is None:
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
else:
args = (
generation_config.beam_search_scorer_args
if generation_config.beam_search_scorer_args is not None
else {}
)
beam_scorer = generation_config.beam_search_scorer_class(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
num_beam_groups=generation_config.num_beam_groups,
max_length=generation_config.max_length,
**args,
)

# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
@@ -2038,16 +2055,33 @@

elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
# 11. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
num_beam_groups=generation_config.num_beam_groups,
max_length=generation_config.max_length,
)
if generation_config.beam_search_scorer_class is None:
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
else:
args = (
generation_config.beam_search_scorer_args
if generation_config.beam_search_scorer_args is not None
else {}
)
beam_scorer = generation_config.beam_search_scorer_class(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
num_beam_groups=generation_config.num_beam_groups,
max_length=generation_config.max_length,
**args,
)
# 12. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
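Because `generate` now instantiates whichever class is configured, a subclass can hook the scorer's `process` and `finalize` steps; that is essentially what the new tests below verify. A short sketch of such a subclass; the counting and printing are purely illustrative.

```python
# Illustrative scorer subclass: it only reports how the two hooks are used.
from transformers import BeamSearchScorer


class LoggingBeamSearchScorer(BeamSearchScorer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.process_calls = 0

    def process(self, *args, **kwargs):
        # Invoked once per decoding step to select the surviving beams.
        self.process_calls += 1
        return super().process(*args, **kwargs)

    def finalize(self, *args, **kwargs):
        # Invoked once at the end to pick the returned hypotheses.
        print(f"finalize() after {self.process_calls} process() calls")
        return super().finalize(*args, **kwargs)
```

Passed as `beam_search_scorer_class=LoggingBeamSearchScorer` (no `beam_search_scorer_args` needed in this case), it replaces the default scorer in both the plain and the group beam search branches above.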
120 changes: 120 additions & 0 deletions tests/generation/test_utils.py
@@ -65,6 +65,7 @@
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput,
BeamSearchScorer,
DisjunctiveConstraint,
GenerateBeamDecoderOnlyOutput,
GenerateBeamEncoderDecoderOutput,
@@ -580,6 +581,65 @@ def test_beam_search_generate(self):
else:
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])

def test_custom_beam_search_scorer_generate(self):
class CustomBeamSearchScorer(BeamSearchScorer):
finalize_called = False
process_called = False

def __init__(inside_self, test_args, *args, **kwargs):
super().__init__(*args, **kwargs)
inside_self.test_args = test_args
self.assertTrue(
inside_self.test_args,
"The argument `test_args` should " "have been passed to the beam search scorer init function",
)

def process(self, *args, **kwargs):
results = super().process(*args, **kwargs)
CustomBeamSearchScorer.process_called = True
return results

def finalize(self, *args, **kwargs):
results = super().finalize(*args, **kwargs)
CustomBeamSearchScorer.finalize_called = True
return results

for model_class in self.all_generative_model_classes:
CustomBeamSearchScorer.process_called = False
CustomBeamSearchScorer.finalize_called = False

config, input_ids, attention_mask = self._get_input_ids_and_config()

model = model_class(config).to(torch_device).eval()

logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
input_ids.shape[-1],
config.forced_bos_token_id,
config.forced_eos_token_id,
)
beam_kwargs = self._get_beam_kwargs()

model_kwargs = {"attention_mask": attention_mask}
output_generate = model.generate(
input_ids,
do_sample=False,
max_new_tokens=self.max_new_tokens,
beam_search_scorer_class=CustomBeamSearchScorer,
beam_search_scorer_args={"test_args": True},
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
**beam_kwargs,
**logits_process_kwargs,
**model_kwargs,
)

self.assertIsNotNone(output_generate)
self.assertTrue(CustomBeamSearchScorer.process_called)
self.assertTrue(CustomBeamSearchScorer.finalize_called)

def test_beam_search_generate_dict_output(self):
for model_class in self.all_generative_model_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
@@ -2277,6 +2337,66 @@ def test_beam_search_example_integration(self):

self.assertListEqual(outputs, ["Wie alt bist du?"])

@slow
def test_beam_search_custom_scorer(self):
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

encoder_input_str = "translate English to German: How old are you?"
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

class CustomBeamSearchScorer(BeamSearchScorer):
finalize_called = False
process_called = False

def __init__(inside_self, test_args, *args, **kwargs):
super().__init__(*args, **kwargs)
inside_self.test_args = test_args
self.assertTrue(
inside_self.test_args,
("The argument `test_args` should " "have been passed to the beam search scorer init function"),
)

def process(self, *args, **kwargs):
results = super().process(*args, **kwargs)
CustomBeamSearchScorer.process_called = True
return results

def finalize(self, *args, **kwargs):
results = super().finalize(*args, **kwargs)
CustomBeamSearchScorer.finalize_called = True
return results

# let's run beam search using 3 beams
num_beams = 3
# define decoder start token ids
input_ids = torch.ones((1, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id

# add encoder_outputs to model keyword arguments
model_kwargs = {"encoder_outputs": model.get_encoder()(encoder_input_ids, return_dict=True)}

outputs = model.generate(
input_ids,
num_beams=num_beams,
min_length=5,
eos_token_id=model.config.eos_token_id,
beam_search_scorer_class=CustomBeamSearchScorer,
beam_search_scorer_args={"test_args": True},
**model_kwargs,
)
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

self.assertListEqual(outputs, ["Wie alt bist du?"])
self.assertTrue(
CustomBeamSearchScorer.process_called,
"The `process` function of the custom beam search scorer was not called",
)
self.assertTrue(
CustomBeamSearchScorer.finalize_called,
"The `finalize` function of the custom beam search scorer was not called",
)

@slow
def test_constrained_beam_search(self):
# PT-only test: TF doesn't have constrained beam search