Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the max_matching_ngram_size to GenerationConfig #29131

Merged
merged 7 commits into from
Mar 6, 2024
Merged
4 changes: 2 additions & 2 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
def __init__(
self,
num_output_tokens: int = 10,
max_matching_ngram_size: int = 2,
max_matching_ngram_size: int = None,
):
self.num_output_tokens = num_output_tokens
self.max_matching_ngram_size = max_matching_ngram_size
self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2

if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
Expand Down
1 change: 1 addition & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def __init__(self, **kwargs):

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None)

# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
Expand Down
1 change: 1 addition & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def _get_candidate_generator(
if generation_config.prompt_lookup_num_tokens is not None:
candidate_generator = PromptLookupCandidateGenerator(
num_output_tokens=generation_config.prompt_lookup_num_tokens,
max_matching_ngram_size=generation_config.max_matching_ngram_size
)
else:
candidate_generator = AssistedCandidateGenerator(
Expand Down