Commit

fix beam search with static cache
zucchini-nlp committed Aug 30, 2024
1 parent 2b47ea9 commit 4dd1494
Showing 2 changed files with 2 additions and 2 deletions.
src/transformers/generation/utils.py: 2 changes (1 addition, 1 deletion)

@@ -1548,7 +1548,7 @@ def _prepare_cache_for_generation(
         )
         model_kwargs[cache_name] = self._get_cache(
             cache_implementation=generation_config.cache_implementation,
-            batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+            batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
             max_cache_len=max_cache_length,
             device=device,
             model_kwargs=model_kwargs,
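For context: generation expands the batch to batch_size * num_beams rows during beam search, and to batch_size * num_return_sequences rows during sampling (where num_beams is 1), so the two factors never compound and the static cache only needs the larger of the two. The sketch below illustrates that sizing rule with a hypothetical helper, expanded_batch_size, which is not part of transformers.

# Illustrative sketch of the sizing logic behind this change; `expanded_batch_size`
# is a hypothetical helper, not a transformers API.
def expanded_batch_size(batch_size: int, num_beams: int, num_return_sequences: int) -> int:
    # Beam search expands the batch to batch_size * num_beams; multinomial sampling
    # with several return sequences (num_beams == 1) expands it to
    # batch_size * num_return_sequences. The cache needs the larger of the two,
    # never their product.
    return max(num_beams, num_return_sequences) * batch_size

# Beam search: 2 prompts, 4 beams, 2 returned sequences -> 8 rows, not 16.
assert expanded_batch_size(batch_size=2, num_beams=4, num_return_sequences=2) == 8
# Sampling: 2 prompts, 3 returned sequences, no beams -> 6 rows.
assert expanded_batch_size(batch_size=2, num_beams=1, num_return_sequences=3) == 6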
tests/generation/test_utils.py: 2 changes (1 addition, 1 deletion)

@@ -1420,7 +1420,7 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
         signature = inspect.signature(model.forward).parameters.keys()

         # no cache as some models require special cache classes to be init outside forward
-        model.genertation_config.use_cache = False
+        model.generation_config.use_cache = False

         # Without padding
         model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
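End to end, the scenario this commit fixes looks roughly like the sketch below. The checkpoint name and generation settings are assumptions for illustration; a static cache requires an architecture that supports it (e.g. Llama-style models).

# Illustrative usage sketch (not from this commit): beam search with a static KV cache.
# The checkpoint name and settings are assumptions; static cache only works with
# architectures that support it.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any static-cache-capable checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer(["Beam search with a static cache"], return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    num_beams=4,                    # beam search ...
    num_return_sequences=2,         # ... keeping 2 of the 4 beams
    max_new_tokens=32,
    cache_implementation="static",  # pre-allocated cache, sized max(4, 2) * batch_size after this fix
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))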
