
Generate: get the correct beam index on eos token #18851

Merged -- 2 commits merged into huggingface:main from the batched_beam_search_fix branch on Sep 5, 2022

Conversation

@gante (Member) commented on Sep 1, 2022

What does this PR do?

Fixes #18839

We were not storing the correct beam index when an eos_token was generated (except for the first batch member), resulting in the issue linked above.


To confirm the change, consider the following script, which obtains the scores both from `output.sequences_scores` and from `model.compute_transition_beam_scores`. Since there is no length penalty, the sum of the transition scores divided by the sequence length should match `output.sequences_scores` -- with the current codebase, this only held for the first batch member.

```python
from transformers import BartTokenizer, BartForConditionalGeneration

model_id = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_id)
model = BartForConditionalGeneration.from_pretrained(model_id)

input_tokens = ["what do you think it ? huggingface is a great library. And I enjoy it very much",
                "transformers is so good"]
batch_size = 2
num_beams = 10
max_length = 10
num_return_sequences = 5
input_ids = tokenizer(input_tokens, return_tensors='pt', padding=True).input_ids
output = model.generate(
    input_ids,
    max_length=max_length,
    num_beams=num_beams,
    num_return_sequences=num_return_sequences,
    return_dict_in_generate=True,
    output_scores=True
)
print("\nbeam indices:\n", output.beam_indices)

# -1 marks padding in beam_indices, so this counts each sequence's generated length
beam_lengths = (output.beam_indices != -1).sum(dim=1)
# Recompute the per-step transition scores from the stored beam indices
beam_scores = model.compute_transition_beam_scores(
    output.sequences, output.scores, output.beam_indices, tokenizer.eos_token_id
)
print("\nsequence scores (from outputs):\n", output.sequences_scores)
print("\nsequence scores (from compute_transition_beam_scores):\n", beam_scores.sum(dim=1) / beam_lengths)
```

🚫 output before this PR (note that, for the second batch member, the step where eos is generated stores the within-batch beam index 0/1/2 instead of the global index 10/11/12, and the recomputed scores diverge):

```
beam indices:
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  2, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  3, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  4, -1],
        [10, 10, 10, 10, 10, 10,  0, -1, -1, -1],
        [10, 10, 10, 10, 10, 10, 10,  0, -1, -1],
        [10, 10, 11, 11, 11, 11,  1, -1, -1, -1],
        [10, 10, 10, 10, 10, 10, 10,  1, -1, -1],
        [10, 10, 12, 12, 12, 12,  2, -1, -1, -1]])

sequence scores (from outputs):
 tensor([-2.4142e-02, -5.1596e-01, -5.2848e-01, -6.2190e-01, -6.2194e-01,
        -4.1643e-04, -1.0500e+00, -1.1113e+00, -1.1323e+00, -1.1955e+00])

sequence scores (from compute_transition_beam_scores):
 tensor([-0.0241, -0.5160, -0.5285, -0.6219, -0.6219, -2.4050, -2.5656, -3.4137,
        -2.3775, -3.5453])
```

✅ output after this PR (the stored beam indices stay within the correct batch member's range, and the two score computations agree):

```
beam indices:
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  2, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  3, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  4, -1],
        [10, 10, 10, 10, 10, 10, 10, -1, -1, -1],
        [10, 10, 10, 10, 10, 10, 10, 10, -1, -1],
        [10, 10, 11, 11, 11, 11, 11, -1, -1, -1],
        [10, 10, 10, 10, 10, 10, 10, 11, -1, -1],
        [10, 10, 12, 12, 12, 12, 12, -1, -1, -1]])

sequence scores (from outputs):
 tensor([-2.4142e-02, -5.1596e-01, -5.2848e-01, -6.2190e-01, -6.2194e-01,
        -4.1643e-04, -1.0500e+00, -1.1113e+00, -1.1323e+00, -1.1955e+00])

sequence scores (from compute_transition_beam_scores):
 tensor([-2.4142e-02, -5.1596e-01, -5.2848e-01, -6.2190e-01, -6.2194e-01,
        -4.1643e-04, -1.0500e+00, -1.1113e+00, -1.1323e+00, -1.1955e+00])
```
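To read the beam_indices tensor: each entry is an index into the batch-flattened set of beams, so it can be decomposed into a batch member and a within-batch beam. A small illustration (hypothetical helper, using the num_beams=10 setup from the script above):

```python
num_beams = 10  # as in the script above

def split_beam_index(global_beam: int, num_beams: int) -> tuple[int, int]:
    """Decompose a batch-flattened beam index (illustrative helper)."""
    return global_beam // num_beams, global_beam % num_beams

# Global beam 10 is beam 0 of the second batch member,
# matching the rows for the second input sentence above.
print(split_beam_index(10, num_beams))  # (1, 0)
```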

```diff
@@ -259,7 +259,7 @@ def process(
                 continue
             if beam_indices is not None:
                 beam_index = beam_indices[batch_beam_idx]
-                beam_index = beam_index + (next_index,)
+                beam_index = beam_index + (batch_beam_idx,)
```
@gante (Member, Author) commented on Sep 1, 2022


batch_beam_idx -> beam index over the batch-flattened beams (between num_beams * batch_idx and num_beams * (batch_idx + 1) - 1, for 0-based batch_idx)

next_index -> beam index within the current batch member (between 0 and num_beams - 1)
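A toy illustration of the two indices (not the library code, assuming 0-based batch_idx as in the generation loop): they coincide only for the first batch member, which is why the pre-fix code looked correct there:

```python
num_beams = 10

for batch_idx in (0, 1):
    for next_index in range(num_beams):
        # Index into the batch-flattened beams:
        batch_beam_idx = batch_idx * num_beams + next_index
        # They agree only for the first batch member; for batch_idx == 1,
        # batch_beam_idx runs 10..19 while next_index still runs 0..9,
        # so storing next_index pointed at the wrong batch member's beams.
        assert (next_index == batch_beam_idx) == (batch_idx == 0)
```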

A contributor replied:

Ah yeah great point!

@HuggingFaceDocBuilderDev commented on Sep 1, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten (Contributor) left a review comment

Thanks for diving into this very complex logic!

gante merged commit d4dbd7c into huggingface:main on Sep 5, 2022
gante deleted the batched_beam_search_fix branch on September 5, 2022 at 18:35
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
Merging this pull request closed: BUG for beam_indices from model.generate() (#18839)