Bugfix/torchtext include lengths #2689
Conversation
…xt.data.Field configured as include_lengths=True
Hello @thschaaf! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-07-30 01:53:43 UTC
@thschaaf Thanks for the PR. How much value do you see in this support? I was told by a torchtext member that they will drop the Batch class from torchtext moving forward.
@awaelchli Hopefully when torchtext removes the Batch class they do it without breaking too much code (in a substantial way). There is enough value in supporting this until torchtext actually changes their implementation. In my case the Batch object of torchtext is used behind the curtain, and it caused my Skip-Thought model training to fail on GPUs. I am sure others might run into similar issues, and I was quite happy that the change to Pytorch-Lightning is quite compact. With this change RNN training works on a GPU. What do you think about the added tests? They are agnostic to the underlying Batch class, only making sure that the torchtext.data.Field parameter include_lengths=True is tested. They might be useful in the future, even after the code dealing with the torchtext Batch class is removed. The change torchtext is planning seems sensible, but more people might use Pytorch-Lightning together with torchtext before then. It is not intuitive if a model trains on the CPU but throws such an exception on the GPU. Of course, having the PR changes become part of PL would make my Skip-Thought model train faster and my life easier.
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
@Borda Thanks! For some reason, which I don't remember exactly, I was following the pytorch-lightning project. Simply unfollowing the project did the trick.
looks good, thanks for the fix @thschaaf
This pull request is now in conflict... :(
What does this PR do?
It fixes a bug that arises when transferring data to the GPU while using torchtext with a torchtext.data.Field configured with include_lengths=True.
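For context, a minimal sketch (using the legacy torchtext API that this PR targets; the field and token values are purely illustrative) of how include_lengths=True makes a batch attribute a (padded_tensor, lengths) tuple rather than a plain tensor, which is what the device transfer stumbled over:

```python
from torchtext.data import Dataset, Example, Field, Iterator

# include_lengths=True means batches carry (padded_tensor, lengths) tuples
text_field = Field(sequential=True, include_lengths=True, batch_first=True)
fields = [("text", text_field)]

examples = [Example.fromlist([tokens], fields)
            for tokens in (["hello", "world", "!"], ["hi", "there"])]
dataset = Dataset(examples, fields)
text_field.build_vocab(dataset)

batch = next(iter(Iterator(dataset, batch_size=2, shuffle=False)))
padded, lengths = batch.text  # a tuple of tensors, not a single tensor
```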
It adds tests to check that Batches created by torchtext with include_lengths=True and include_lengths=False are processed correctly by Trainer.fit().
The fix checks whether the data is a Tensor, tuple, or list before sending it to the device. If it is a tuple or list, it iterates over the elements and sends each of them to the device. (The implementation has since changed and now uses move_data_to_device recursively.)
Fixes #2688
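For illustration, a minimal sketch of the recursive transfer idea described above (this is not Lightning's actual implementation, which delegates to its move_data_to_device utility; the helper name here is hypothetical):

```python
import torch

def move_to_device(data, device):
    """Recursively move tensors, and tuples/lists of tensors, to a device."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type so (padded, lengths) tuples produced
        # by include_lengths=True keep their structure after the transfer.
        return type(data)(move_to_device(item, device) for item in data)
    return data  # leave non-tensor leaves (ints, strings, ...) untouched
```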
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃