
Using FP16 on BartModel #3249

Closed · 2 of 4 tasks
AOZMH opened this issue Mar 12, 2020 · 3 comments · Fixed by #3266

AOZMH commented Mar 12, 2020

🐛 Bug

Information

Model I am using (Bert, XLNet ...): BART

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: CNN/DM
  • my own task or dataset: (give details below)

To reproduce

I've installed the master branch of transformers, but I still encounter the same issue as #3117 when using BartModel in FP16. I initialized the model without loading the pretrained weights, but the model should still be able to correctly forward a LongTensor input of shape (batch, seq_length). The code below simply initializes a model and forwards an input:

import torch
from transformers import BartModel, BartConfig

model = BartModel(BartConfig())  # random weights, no pretrained checkpoint
model = model.cuda().half()
cur_inputs = torch.zeros(4, 16, dtype=torch.long).cuda()
cur_res = model(cur_inputs)

The error is:

~\Anaconda3\envs\pytorch\lib\site-packages\transformers\modeling_bart.py in forward(self, query, key, value, key_padding_mask, layer_state, need_weights, static_kv, attn_mask)
assert v is not None
--> attn_output = torch.bmm(attn_probs, v)
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm
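
From the trace, torch.bmm receives mismatched dtypes: attn_probs comes out of the softmax as float32, while v has already been cast to float16. A minimal sketch of the same mismatch in plain PyTorch, independent of BART, with the cast that would align the two (only torch is assumed):

import torch

# Simulated attention tensors on GPU: the softmax output stays float32,
# while the value tensor has been cast to half, as in the trace above.
attn_probs = torch.softmax(torch.randn(8, 16, 16, device="cuda"), dim=-1)  # float32
v = torch.randn(8, 16, 64, device="cuda").half()                           # float16

# torch.bmm(attn_probs, v) fails with the same dtype mismatch;
# casting the probabilities to v's dtype makes the product well-defined.
attn_output = torch.bmm(attn_probs.to(v.dtype), v)
print(attn_output.dtype)  # torch.float16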

@sshleifer The model is still quite new to me, so am I using it incorrectly, or is there still a bug in the BartModel class? Thanks in advance for the help!

Environment info

  • transformers version: master branch
  • Platform: Windows
  • Python version: 3.7.0
  • PyTorch version (GPU?): 1.4.0
  • Tensorflow version (GPU?): /
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No
sshleifer self-assigned this Mar 12, 2020

AOZMH commented Mar 13, 2020

@sshleifer May I ask whether you could reproduce the error on your machine? I ran the same code on a Linux machine with the master branch of transformers, but still got the same error. I'm planning to use BartModel these days, so please notify me at your earliest convenience if there are any updates. Many thanks!

sshleifer (Contributor) commented Mar 13, 2020

Yes, will try to fix it today! Thanks for reporting!

sshleifer linked pull request #3266 on Mar 13, 2020 that will close this issue

AOZMH commented Mar 14, 2020

> Yes, will try to fix it today! Thanks for reporting!

Thanks Sam,

The code works well this time! Thanks again for the contribution.
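
For anyone landing here later, a quick sanity check on an updated master install, reusing the repro from above (the expected dtype is an assumption, not taken from the thread):

# Re-run the original repro on a build that includes the fix.
cur_res = model(cur_inputs)
print(cur_res[0].dtype)  # expected: torch.float16, with no bmm dtype error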
