Bug when combining grouped beam search and constrained prefix decoding #10415

Closed · mnschmit opened this issue Feb 26, 2021 · 1 comment · Fixed by #10475

mnschmit commented Feb 26, 2021

Environment info

  • transformers version: 4.3.3
  • Platform: Linux-5.8.0-38-generic-x86_64-with-glibc2.29
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.5.1+cu101 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help

@patrickvonplaten

Information

Model I am using (Bert, XLNet ...): T5

The problem arises when using: my own modified scripts

To reproduce

Steps to reproduce the behavior: run the following minimal script:

from transformers import T5TokenizerFast, T5ForConditionalGeneration

tokenizer = T5TokenizerFast.from_pretrained('t5-small')

inp = 'The <extra_id_0> walks in <extra_id_1> park'
enc_inp = tokenizer(inp, return_tensors='pt')
model = T5ForConditionalGeneration.from_pretrained('t5-small')


def prefix_allowed_tokens_fn(batch_id, input_ids):
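    # Constrain every decoding step to a single dummy token id;
    # any constraint is enough to trigger the bug.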
    return [2]  # dummy value


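# With num_beams=2 and num_beam_groups=2, each beam group contains a single sub-beam.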
out = model.generate(
    **enc_inp,
    num_beams=2,
    num_beam_groups=2,
    diversity_penalty=0.2,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn
)

This produces the following error:

Traceback (most recent call last):
  File "debugging/grouped_beam_search.py", line 14, in <module>
    out = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
    return func(*args, **kwargs)
  File "/mounts/Users/student/martin/.local/lib/python3.8/site-packages/transformers/generation_utils.py", line 1041, in generate
    return self.group_beam_search(
  File "/mounts/Users/student/martin/.local/lib/python3.8/site-packages/transformers/generation_utils.py", line 2161, in group_beam_search
    next_token_scores = logits_processor(
  File "/mounts/Users/student/martin/.local/lib/python3.8/site-packages/transformers/generation_logits_process.py", line 89, in __call__
    scores = processor(input_ids, scores)
  File "/mounts/Users/student/martin/.local/lib/python3.8/site-packages/transformers/generation_logits_process.py", line 458, in __call__
    for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
RuntimeError: shape '[-1, 2, 1]' is invalid for input of size 1

Expected behavior

No error.

As far as I can tell, the PrefixConstrainedLogitsProcessor still receives the total number of beams even when grouped beam search is used, but it should receive the number of sub-beams per group. In group_beam_search, the logits processors are applied separately to each beam group, so input_ids contains only num_beams // num_beam_groups rows per batch item, and the view(-1, self._num_beams, ...) reshape in the traceback above fails. Replacing num_beams with num_beams // num_beam_groups in the call that constructs PrefixConstrainedLogitsProcessor inside _get_logits_processor in generation_utils.py should fix it.
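For concreteness, a minimal sketch of the proposed change (the surrounding code in _get_logits_processor is paraphrased and may differ slightly across versions; only the second constructor argument changes):

if prefix_allowed_tokens_fn is not None:
    processors.append(
        PrefixConstrainedLogitsProcessor(
            prefix_allowed_tokens_fn,
            num_beams // num_beam_groups,  # was: num_beams
        )
    )

With the default num_beam_groups=1 this reduces to the previous behavior (num_beams // 1 == num_beams), so regular beam search is unaffected.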
What do you think?

@patrickvonplaten

Hey @mnschmit,

thanks for your bug report! Yes, you're right -> I think we should indeed replace num_beams with num_beams // num_beam_groups. Do you want to open a PR to fix it? :-) Otherwise, I can do it as well.
