Skip to content

Commit 5437234

Browse files
committed
Adding support to pass arguments to custom beam search scorer
1 parent 5351f0f commit 5437234

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

src/transformers/generation/configuration_utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ class GenerationConfig(PushToHubMixin):
136136
beam_search_scorer_class: (`class`, *optional*, defaults to `None`):
137137
Which class to use as a beam search scorer. If `None`, it will use the default `BeamSearchScorer` class.
138138
The type must inherit from `BeamSearchScorer`.
139+
beam_search_scorer_args: (`dict`, *optional*, defaults to `None`):
140+
Arguments that will be passed when creating the beam search scorer. When this argument is specified,
141+
`beam_search_scorer_class` must not be `None`.
139142
140143
> Parameters for manipulation of the model output logits
141144
@@ -357,6 +360,7 @@ def __init__(self, **kwargs):
357360
self.penalty_alpha = kwargs.pop("penalty_alpha", None)
358361
self.use_cache = kwargs.pop("use_cache", True)
359362
self.beam_search_scorer_class = kwargs.pop("beam_search_scorer_class", None)
363+
self.beam_search_scorer_args = kwargs.pop("beam_search_scorer_args", None)
360364

361365
# Parameters for manipulation of the model output logits
362366
self.temperature = kwargs.pop("temperature", 1.0)
@@ -652,6 +656,11 @@ def validate(self, is_init=False):
652656
UserWarning,
653657
)
654658

659+
if self.beam_search_scorer_class is None and self.beam_search_scorer_args is not None:
660+
raise ValueError(
661+
"The initialization arguments for the beam search scorer class were provided, but the class was not",
662+
)
663+
655664
# 3. detect incorrect parameterization specific to advanced beam modes
656665
else:
657666
# constrained beam search

src/transformers/generation/utils.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2004,7 +2004,6 @@ def generate(
20042004
if generation_config.do_sample
20052005
else None
20062006
)
2007-
20082007
# 12. prepare beam search scorer
20092008
if generation_config.beam_search_scorer_class is None:
20102009
beam_scorer = BeamSearchScorer(
@@ -2017,6 +2016,11 @@ def generate(
20172016
max_length=generation_config.max_length,
20182017
)
20192018
else:
2019+
args = (
2020+
generation_config.beam_search_scorer_args
2021+
if generation_config.beam_search_scorer_args is not None
2022+
else {}
2023+
)
20202024
beam_scorer = generation_config.beam_search_scorer_class(
20212025
batch_size=batch_size,
20222026
num_beams=generation_config.num_beams,
@@ -2026,6 +2030,7 @@ def generate(
20262030
num_beam_hyps_to_keep=generation_config.num_return_sequences,
20272031
num_beam_groups=generation_config.num_beam_groups,
20282032
max_length=generation_config.max_length,
2033+
**args,
20292034
)
20302035

20312036
# 13. interleave input_ids with `num_beams` additional sequences per batch
@@ -2061,6 +2066,11 @@ def generate(
20612066
max_length=generation_config.max_length,
20622067
)
20632068
else:
2069+
args = (
2070+
generation_config.beam_search_scorer_args
2071+
if generation_config.beam_search_scorer_args is not None
2072+
else {}
2073+
)
20642074
beam_scorer = generation_config.beam_search_scorer_class(
20652075
batch_size=batch_size,
20662076
num_beams=generation_config.num_beams,
@@ -2070,6 +2080,7 @@ def generate(
20702080
num_beam_hyps_to_keep=generation_config.num_return_sequences,
20712081
num_beam_groups=generation_config.num_beam_groups,
20722082
max_length=generation_config.max_length,
2083+
**args,
20732084
)
20742085
# 12. interleave input_ids with `num_beams` additional sequences per batch
20752086
input_ids, model_kwargs = self._expand_inputs_for_generation(

tests/generation/test_utils.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,13 @@ class CustomBeamSearchScorer(BeamSearchScorer):
586586
finalize_called = False
587587
process_called = False
588588

589-
def __init__(self, *args, **kwargs):
589+
def __init__(inside_self, test_args, *args, **kwargs):
590590
super().__init__(*args, **kwargs)
591+
inside_self.test_args = test_args
592+
self.assertTrue(
593+
inside_self.test_args,
594+
"The argument `test_args` should " "have been passed to the beam search scorer init function",
595+
)
591596

592597
def process(self, *args, **kwargs):
593598
results = super().process(*args, **kwargs)
@@ -620,6 +625,7 @@ def finalize(self, *args, **kwargs):
620625
do_sample=False,
621626
max_new_tokens=self.max_new_tokens,
622627
beam_search_scorer_class=CustomBeamSearchScorer,
628+
beam_search_scorer_args={"test_args": True},
623629
output_scores=False,
624630
output_logits=False,
625631
output_attentions=False,
@@ -2343,8 +2349,13 @@ class CustomBeamSearchScorer(BeamSearchScorer):
23432349
finalize_called = False
23442350
process_called = False
23452351

2346-
def __init__(self, *args, **kwargs):
2352+
def __init__(inside_self, test_args, *args, **kwargs):
23472353
super().__init__(*args, **kwargs)
2354+
inside_self.test_args = test_args
2355+
self.assertTrue(
2356+
inside_self.test_args,
2357+
("The argument `test_args` should " "have been passed to the beam search scorer init function"),
2358+
)
23482359

23492360
def process(self, *args, **kwargs):
23502361
results = super().process(*args, **kwargs)
@@ -2371,6 +2382,7 @@ def finalize(self, *args, **kwargs):
23712382
min_length=5,
23722383
eos_token_id=model.config.eos_token_id,
23732384
beam_search_scorer_class=CustomBeamSearchScorer,
2385+
beam_search_scorer_args={"test_args": True},
23742386
**model_kwargs,
23752387
)
23762388
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

0 commit comments

Comments
 (0)