Prevent BatchEncoding from blindly passing casts down to the tensors it contains #8860

Merged on Dec 1, 2020 (1 commit)

@Craigacp Craigacp commented Nov 30, 2020

What does this PR do?

This PR prevents BatchEncoding.to from passing anything other than a device down to the tensors it contains. Previously it passed down all of its arguments. Because the to method in PyTorch can also cast a tensor to a different type, other packages (e.g. Nvidia's Apex) call it blindly to perform casts. This caused an issue when using Apex's AMP support with O2 or greater: the token indexes were cast from a LongTensor to a HalfTensor, truncating our vocab at 65k and rounding most of the words to the nearest 8th word (if you blindly insert the cast back in the embedding layer, which the warning says to do).
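
To illustrate the rounding described above, here is a minimal sketch that round-trips a few integer token indexes through IEEE 754 half precision. It uses the standard library's `struct` format `e` rather than PyTorch, so it is only an approximation of what a HalfTensor cast does; the helper name `to_half_and_back` and the sample indexes are made up for this example.

```python
import struct


def to_half_and_back(index: int) -> float:
    """Round-trip an integer through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", float(index)))[0]


# Half precision has an 11-bit significand, so integers above 2048 are no
# longer all representable and get rounded to a coarser grid.
for token_id in (1000, 4097, 10001, 30001):
    print(token_id, "->", int(to_half_and_back(token_id)))

# Values far beyond the half-precision maximum (65504) cannot be packed at all.
# Note: struct raises OverflowError here; a tensor cast would not raise.
try:
    to_half_and_back(70000)
except OverflowError as err:
    print("70000 ->", err)
```

Small indexes survive unchanged, while larger ones collapse onto neighbouring values, which is why an embedding lookup after such a cast can silently return the wrong rows.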

The doc for BatchEncoding.to says it is only for moving the encoding and the tensors it contains between devices. However, as type checking isn't on by default, it can behave like a regular PyTorch to method and accept cast arguments, which it then passes down to the tensors it contains.
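
The idea of a device-only `to` can be sketched roughly as follows. This is not the code from this PR or from transformers: `FakeDtype`, `FakeTensor` and `DeviceOnlyBatch` are hypothetical stand-ins written so the example runs without PyTorch, and the check on `str`/`int` is an assumption about what counts as a device.

```python
import warnings


class FakeDtype:
    """Hypothetical stand-in for a tensor dtype such as half precision."""

    def __init__(self, name):
        self.name = name


class FakeTensor:
    """Hypothetical stand-in for a tensor; only tracks its device and dtype."""

    def __init__(self, device="cpu", dtype="int64"):
        self.device = device
        self.dtype = dtype

    def to(self, target):
        # Overloaded like a tensor's to method: a dtype casts, anything else moves.
        if isinstance(target, FakeDtype):
            return FakeTensor(self.device, target.name)
        return FakeTensor(target, self.dtype)


class DeviceOnlyBatch:
    """Hypothetical container whose to method only accepts devices."""

    def __init__(self, data):
        self.data = data

    def to(self, device):
        if isinstance(device, (str, int)):
            # Looks like a device: forward it to every contained tensor.
            self.data = {k: v.to(device) for k, v in self.data.items()}
        else:
            # Anything else (e.g. a dtype) is refused and the data is left as is.
            warnings.warn(f"Attempting to cast a batch to {device!r}; this is not supported.")
        return self


batch = DeviceOnlyBatch({"input_ids": FakeTensor()})
batch.to("cuda:0")              # moves the tensors
batch.to(FakeDtype("float16"))  # warns, leaves the dtype unchanged
print(batch.data["input_ids"].device, batch.data["input_ids"].dtype)
```

Warning instead of raising keeps callers that blindly forward cast arguments working, while the token indexes keep their integer type.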

Fixes #6582

Before submitting

There are no doc or test changes, as the change makes the method conform to its currently documented behaviour.

@LysandreJik LysandreJik left a comment


LGTM! We prefer using f-strings so I suggested a change. Thank you for opening the PR.

Review suggestion on src/transformers/tokenization_utils_base.py (outdated, resolved)
…it contains. Fixes huggingface#6582.

Update src/transformers/tokenization_utils_base.py with review fix

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@Craigacp Craigacp force-pushed the batch-encoding-cast-check branch from d35f6f6 to 2a33400 Compare December 1, 2020 17:21
Craigacp commented Dec 1, 2020

black complained about the style after the update, so I fixed it and squashed the commits again.

LysandreJik commented

Thank you @Craigacp!

@LysandreJik LysandreJik merged commit 9c18f15 into huggingface:master Dec 1, 2020
@Craigacp Craigacp deleted the batch-encoding-cast-check branch December 1, 2020 18:59
stas00 pushed a commit to stas00/transformers that referenced this pull request Dec 5, 2020
…it contains. Fixes huggingface#6582. (huggingface#8860)

Update src/transformers/tokenization_utils_base.py with review fix

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Development

Successfully merging this pull request may close these issues.

BatchEncoding interacts poorly with apex.amp