
BART FP16 #3117

Closed
astariul opened this issue Mar 4, 2020 · 8 comments · Fixed by #3145

@astariul
Contributor

astariul commented Mar 4, 2020

🚀 Feature request

I would like to use BART in FP16 mode, but it seems impossible for now:

from transformers import AutoModelWithLMHead, AutoTokenizer, BartConfig

config = BartConfig(vocab_size=50264, output_past=True)
model = AutoModelWithLMHead.from_pretrained('bart-large-cnn', config=config).cuda().half()
tokenizer = AutoTokenizer.from_pretrained('bart-large-cnn')

ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')

# Fails inside the attention layer once the model weights are in half precision
generated_ids = model.generate(inputs['input_ids'].cuda(), attention_mask=inputs['attention_mask'].cuda(), num_beams=4, max_length=5)

File "/data/user/.venv/bartqg/lib/python3.6/site-packages/transformers/modeling_bart.py", line 647, in forward
attn_output = torch.bmm(attn_probs, v)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #2 'mat2' in call to _th_bmm

@sshleifer Do you plan to implement an FP16-friendly version of BART?

sshleifer self-assigned this Mar 4, 2020
@sshleifer
Contributor

Not on my roadmap just yet, but I would definitely consider it if there were lots of demand. Since we only have inference code right now, the benefit seems marginal.

@astariul
Contributor Author

astariul commented Mar 5, 2020

@BramVanroy Should this issue be closed?

FP16 is not implemented yet, and the wontfix label is clear.

Still, keeping the issue open may make it easier for people to find it and signal their interest in FP16.

@thomwolf
Member

thomwolf commented Mar 5, 2020

Indeed, this should not be closed.

@sshleifer, we intend all the models to be compatible with FP16. This is the direction the field is going, and with Volta-level GPUs now widespread, there is less and less reason not to use mixed-precision fine-tuning (half the memory and significantly faster).
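
A quick standalone illustration of the "half the memory" part of that claim (plain PyTorch, not transformers code): the same tensor stored in float16 takes half the bytes of its float32 counterpart.

import torch

x32 = torch.randn(1024, 1024)               # float32 by default
x16 = x32.half()                            # same values cast to float16
print(x32.element_size() * x32.nelement())  # 4194304 bytes (~4 MB)
print(x16.element_size() * x16.nelement())  # 2097152 bytes (~2 MB)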

thomwolf reopened this Mar 5, 2020
stale bot removed the wontfix label Mar 5, 2020
@thomwolf
Member

thomwolf commented Mar 5, 2020

This can probably be fixed by changing the torch.float32 cast here to a cast to the dtype of attn_weights, as is done in the original fairseq code here.

Do you mind fixing this and testing it against the failing script posted in this issue, @sshleifer?
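
A minimal sketch of what that change could look like; the helper name and signature below are illustrative, not the actual modeling_bart.py code. The idea is to keep the softmax in float32 for numerical stability and then cast the result back to the dtype of attn_weights, so the following torch.bmm sees matching half-precision operands.

import torch
import torch.nn.functional as F

def stable_attn_probs(attn_weights, dropout_p=0.0, training=False):
    # Softmax in float32 for numerical stability under fp16 ...
    probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
    # ... then cast back to the input dtype (float16 when the model was .half()'ed),
    # mirroring the fairseq approach, so torch.bmm(probs, v) gets matching dtypes.
    probs = probs.type_as(attn_weights)
    return F.dropout(probs, p=dropout_p, training=training)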

@sshleifer
Contributor

Yep, on it!

@easonnie
Contributor

easonnie commented Mar 5, 2020

Hi @sshleifer, thank you so much for your effort on BART. I ran into the same FP16 issue today. The current BART code can be trained (without FP16) using the run_glue script at https://github.com/huggingface/transformers/blob/master/examples/run_glue.py, so it would be really nice if FP16 training worked as well.
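
For reference, here is a rough sketch of the kind of FP16 fine-tuning step this would enable, using NVIDIA apex amp as run_glue.py did at the time. The model name, hyperparameters, and dummy batch are illustrative only, and whether it runs cleanly with BART depends on the attention fix discussed above.

import torch
from transformers import BartForSequenceClassification, BartTokenizer

try:
    from apex import amp  # apex must be installed separately for fp16 (O1) training
except ImportError:
    amp = None

tokenizer = BartTokenizer.from_pretrained("bart-large")
model = BartForSequenceClassification.from_pretrained("bart-large").cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

if amp is not None:
    # O1 keeps fp32 master weights and patches whitelisted ops to run in fp16
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

batch = tokenizer.batch_encode_plus(
    ["My friends are cool but they eat too many carbs."],
    max_length=128, return_tensors="pt",
)
labels = torch.tensor([0]).cuda()
outputs = model(
    batch["input_ids"].cuda(),
    attention_mask=batch["attention_mask"].cuda(),
    labels=labels,
)
loss = outputs[0]  # models returned tuples in transformers 2.x; the loss comes first

if amp is not None:
    # scale_loss guards against fp16 gradient underflow during backward()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()
optimizer.step()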

sshleifer linked a pull request Mar 5, 2020 that will close this issue
@BramVanroy
Collaborator

My bad, I thought @sshleifer's wontfix label was a note that he wasn't planning to change anything and that no further updates would come, so I closed it. Will keep that in mind for the future.

@thomwolf
Member

thomwolf commented Mar 6, 2020

No problem.

@sshleifer for the moment, please ping me with a DM before adding "wontfix" labels to issues, thanks.

AOZMH mentioned this issue Mar 12, 2020