
Beam search fails when using model parallelism #9200

Closed

TobiasNorlund opened this issue Dec 18, 2020 · 3 comments

Comments

@TobiasNorlund
Contributor

Environment info

  • transformers version: 4.1.1
  • Platform: Linux-4.4.0-194-generic-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.7.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: Yes, two GTX 1080, on a single node
  • Using distributed or parallel set-up in script?: Using model parallelism through model.parallelize()

Who can help

@LysandreJik
@alexorona

Information

Model I am using (Bert, XLNet ...): GPT2

The problem arises when using:

  • the official example scripts:
  • my own modified scripts:

The tasks I am working on are:

  • an official GLUE/SQUaD task:
  • my own task or dataset:

To reproduce

The recent (and awesome!) model.parallelize() doesn't seem to work with beam search decoding at the moment. The behavior can be reproduced on the official huggingface/transformers-pytorch-gpu:4.1.1 Docker image by running the following (on a machine with multiple GPUs):

import transformers

tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")
model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
model.parallelize()  # spread the transformer blocks across the available GPUs

input_ids = tokenizer.encode("This is a test", return_tensors="pt").to("cuda:0")
model.generate(input_ids, num_beams=2)  # num_beams > 1 triggers beam search

This raises the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py", line 612, in generate
    **model_kwargs,
  File "/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py", line 1088, in beam_search
    model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
  File "/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py", line 229, in _reorder_cache
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
  File "/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py", line 229, in <genexpr>
    return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
RuntimeError: Input, output and indices must be on the current device

Expected behavior

The expected behavior is that no error is raised and generate() correctly returns the beam search output.

@TobiasNorlund
Contributor Author

As the traceback suggests, the error seems to come from the _reorder_cache method in generation_utils.py. Since the model is parallelized across multiple devices, it fails because the devices of beam_idx and layer_past don't match for all layers.
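
For reference, the mismatch is easy to see after parallelizing (a minimal check; the actual placement depends on the default device_map, so the printed devices are only an example):

import transformers

model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
model.parallelize()

# The transformer blocks are now spread across the GPUs, while beam_idx inside
# generate() stays on the device of input_ids (cuda:0 in the snippet above).
print(next(model.transformer.h[0].parameters()).device)   # e.g. cuda:0
print(next(model.transformer.h[-1].parameters()).device)  # e.g. cuda:1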

I just tried to modify line 229 in generation_utils.py to:

return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)

which seems to work.
I'm happy to file a PR with this change if you approve. Please let me know if there is anything I should be aware of, or pay extra attention to.
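
In context, the patched method would look roughly like this (a sketch of the change only; the signature is simplified and not copied verbatim from generation_utils.py):

def _reorder_cache(past, beam_idx):
    # Move beam_idx onto each layer's device so that layers placed on other
    # GPUs by model.parallelize() can be reordered as well.
    return tuple(
        layer_past.index_select(1, beam_idx.to(layer_past.device))
        for layer_past in past
    )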

@OyvindTafjord
Contributor

OyvindTafjord commented May 12, 2021

FWIW, this fix doesn't currently work for T5, as the fix to _reorder_cache is not reflected in the modeling_t5.py file. Following the above, changing this line to layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)) appears to fix it.
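
Roughly, the reordering inside T5's _reorder_cache would then become something like the following (the variable names are paraphrased for illustration, not copied from modeling_t5.py):

reordered_layer_past_states = tuple(
    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
    for layer_past_state in layer_past_states
)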

@patrickvonplaten

@patrickvonplaten
Contributor

@OyvindTafjord - would you mind opening a new PR for it? :-)
