BlenderBot RuntimeError: CUDA error: device-side assert triggered #9046

Closed
manzar96 opened this issue Dec 11, 2020 · 5 comments · Fixed by #9131

manzar96 commented Dec 11, 2020

Environment info

  • transformers version: 4.0.0
  • Platform: Linux-5.4.0-56-generic-x86_64-with-glibc2.29
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.7.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: yes (GTX 1060 6GB)
  • Using distributed or parallel set-up in script?: no

Who can help

@patrickvonplaten

Information

Model I am using (Bert, XLNet ...): BlenderbotForConditionalGeneration ('facebook/blenderbot-90M'), along with the corresponding small tokenizer.

The problem arises when using:

I am using my own trainer implementation. I think the problem has to do with the indices of the labels. More specifically, when I use:

outputs = self.model(input_ids=inputs, attention_mask=inputs_att, labels=pad_targets, return_dict=True)

everything works fine, as the "pad_targets" are the targets that use 0 as the index for masked (padded) tokens.
However, when I use:

outputs = self.model(input_ids=inputs, attention_mask=inputs_att, labels=repl_targets, return_dict=True)

and then print outputs['loss'], the following error occurs:

RuntimeError: CUDA error: device-side assert triggered

as the "repl_targets" are the targets using the -100 as the index for masked (padded) tokens.

The aforementioned error also occurs when using the argument:
decoder_input_ids=repl_targets
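
For context: -100 is the default ignore_index of PyTorch's cross-entropy loss, so positions labelled -100 should simply not contribute to the loss. A minimal illustration in plain PyTorch (independent of transformers; shapes chosen arbitrarily for this sketch):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 10)                    # (batch, seq_len, vocab_size)
labels = torch.tensor([[1, 2, -100, -100, -100],  # -100 marks padded positions
                       [3, 4, 5, 6, 7]])

# F.cross_entropy defaults to ignore_index=-100, so the masked
# positions contribute nothing to the loss:
loss = F.cross_entropy(logits.view(-1, 10), labels.view(-1))
print(loss)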

The task I am working on is:
Dialogue generation on the Empathetic Dialogues dataset.

Expected behavior

I think that there is a problem with the -100 padding index, but I am not sure :)

@patrickvonplaten (Contributor)

Hey @manzar96,

It would be awesome if you could provide a full code snippet that I can copy paste and run to reproduce the error. I am not able to do so with your code above.

Thanks a lot!

manzar96 (Author) commented Dec 11, 2020

I made an example:

import torch
from transformers import BlenderbotSmallTokenizer, \
    BlenderbotForConditionalGeneration

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BlenderbotForConditionalGeneration.from_pretrained('facebook/blenderbot-90M')
model.to(DEVICE)

# Two input sequences; the first is padded to length 14 with token id 0.
inputs = torch.tensor([[14, 49, 42, 626, 2727, 1063, 5, 0, 0, 0, 0, 0, 0, 0],
                       [14, 1322, 7, 1427, 13, 7, 153, 384, 5, 14,
                        18, 64, 7261, 5]], device=DEVICE)

inputs_att = torch.tensor([[1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
                           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
                          device=DEVICE)

# Targets with -100 at the padded positions (the ignore_index convention):
repl_targets = torch.tensor([[46, 15, 3283, 20, -100, -100, -100, -100, -100, -100,
                              -100, -100, -100, -100],
                             [121, 54, 37, 53, 60, 12, 447, 10, 1427, 15, 51, 11,
                              598, 20]], device=DEVICE)

# The same targets with 0 at the padded positions:
pad_targets = torch.tensor([[46, 15, 3283, 20, 0, 0, 0, 0, 0, 0, 0, 0,
                             0, 0],
                            [121, 54, 37, 53, 60, 12, 447, 10, 1427, 15, 51, 11,
                             598, 20]], device=DEVICE)

outputs = model(input_ids=inputs, attention_mask=inputs_att,
                labels=repl_targets, return_dict=True)
import ipdb; ipdb.set_trace()

If you try printing outputs['loss'], the error occurs. However, if you replace repl_targets with pad_targets everything works fine (but then the loss does not mask the 0-padded positions, so it is not always correct to use).
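
Until this is fixed in the library, one possible user-side workaround (a sketch reusing the tensors above; the shift_tokens_right helper here replicates the shifting Bart applies internally when it derives decoder_input_ids from labels) is to build decoder_input_ids yourself with -100 replaced by the pad token, while still passing repl_targets as labels so the loss keeps ignoring the padded positions:

pad_token_id = model.config.pad_token_id

# replace the loss-masking value with a real token id before shifting
safe_targets = repl_targets.masked_fill(repl_targets == -100, pad_token_id)

def shift_tokens_right(input_ids, pad_token_id):
    # same shift Bart performs internally (see the snippet further down)
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens

# passing decoder_input_ids explicitly skips the internal shift of labels;
# labels keep -100 so cross-entropy still ignores the padding
outputs = model(input_ids=inputs, attention_mask=inputs_att,
                decoder_input_ids=shift_tokens_right(safe_targets, pad_token_id),
                labels=repl_targets, return_dict=True)
print(outputs['loss'])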

@patil-suraj (Contributor)

@patrickvonplaten

This is a bug: in Bart, decoder_input_ids are prepared by shifting the labels to the right, but shift_tokens_right doesn't replace -100 with pad_token_id.

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    """
    Shift input ids one token to the right, and wrap the last non pad token (usually <eos>).
    """
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens
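
Concretely, with the repl_targets from the reproduction above, the shifted tensor still contains -100 at the padded positions, and that value is then used as an index into the decoder's embedding table, which is what trips the device-side assert on CUDA. A quick check (assuming pad_token_id is 0 for this model):

shifted = shift_tokens_right(repl_targets, pad_token_id=0)
print((shifted == -100).any())  # tensor(True) -> negative, out-of-range embedding index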

In T5 we automatically replace -100 with pad_token_id when preparing decoder_input_ids.

def _shift_right(self, input_ids):
    decoder_start_token_id = self.config.decoder_start_token_id
    pad_token_id = self.config.pad_token_id

    assert (
        decoder_start_token_id is not None
    ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"

    # shift inputs to the right
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
    shifted_input_ids[..., 0] = decoder_start_token_id

    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
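
One way Bart's shift_tokens_right could mirror that (a hedged sketch borrowing T5's masked_fill_; the actual change is in #9131 and may differ):

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
    # replace the loss-masking value with the pad token first, so that
    # no -100 survives into decoder_input_ids
    input_ids = input_ids.masked_fill(input_ids == -100, pad_token_id)
    prev_output_tokens = input_ids.clone()
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens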

patil-suraj self-assigned this Dec 14, 2020
@patrickvonplaten (Contributor)

You're right @patil-suraj - do you want to open a PR to fix it in Bart? :-)

@patil-suraj (Contributor)

Yeah!

patrickvonplaten linked a pull request Dec 15, 2020 that will close this issue