
Commit

Fix beam_scores shape when token scores shape changes after `logits_processor` (huggingface#25980)
BakerBunker authored and parambharat committed Sep 26, 2023
1 parent b48ccf9 commit d04bd5c
Showing 1 changed file with 9 additions and 3 deletions.
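For context, here is a minimal sketch (shapes and values are illustrative, not taken from this diff) of why expanding `beam_scores` against the pre-processing tensor breaks once a logits processor returns scores whose vocab dimension differs from its input:

```python
import torch
import torch.nn.functional as F

batch_size, num_beams, vocab_size, processed_vocab_size = 2, 3, 10, 12

beam_scores = torch.zeros(batch_size * num_beams)                    # (6,)
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)  # (6, 10)

# Pretend the logits processor returned scores with a wider last dimension;
# zero-padding here just stands in for whatever the processor actually does.
next_token_scores_processed = F.pad(next_token_scores, (0, processed_vocab_size - vocab_size))  # (6, 12)

# Before the fix: expand_as targets the pre-processing tensor, so the addend is
# (6, 10) and adding it to the (6, 12) processed scores raises a RuntimeError.
# next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

# After the fix: expand against the processed tensor, so the shapes always agree.
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores_processed)
print(next_token_scores.shape)  # torch.Size([6, 12])
```

The same one-line change is applied at all three call sites below (`beam_search`, `beam_sample`, and `constrained_beam_search`).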
12 changes: 9 additions & 3 deletions src/transformers/generation/utils.py
@@ -3038,7 +3038,9 @@ def beam_search(
) # (batch_size * num_beams, vocab_size)

next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+    next_token_scores_processed
+)

# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
@@ -3363,7 +3365,9 @@ def beam_sample(
) # (batch_size * num_beams, vocab_size)

next_token_scores_processed = logits_processor(input_ids, next_token_scores)
-next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+    next_token_scores_processed
+)
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
# (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
@@ -4080,7 +4084,9 @@ def constrained_beam_search(

next_token_scores_processed = logits_processor(input_ids, next_token_scores)

-next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
+    next_token_scores_processed
+)

scores_for_all_vocab = next_token_scores.clone()

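A shape change like this can come from a user-supplied processor passed through `logits_processor`. As a hedged illustration only (this `PadVocabLogitsProcessor` is hypothetical and not part of the library), such a processor might pad the vocab dimension and so return scores wider than the tensor it received:

```python
import torch
from transformers import LogitsProcessor


class PadVocabLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: pads the vocab dimension up to a multiple of `multiple_of`."""

    def __init__(self, multiple_of: int = 8):
        self.multiple_of = multiple_of

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        pad = (-scores.shape[-1]) % self.multiple_of
        if pad == 0:
            return scores
        # Padded vocab slots get -inf so beam search can never select them.
        filler = torch.full((scores.shape[0], pad), float("-inf"), dtype=scores.dtype, device=scores.device)
        return torch.cat([scores, filler], dim=-1)
```

With this commit, `beam_scores` is broadcast against whatever shape the processor returns, so processors of this kind no longer cause a shape mismatch in the three beam methods.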
