-
Notifications
You must be signed in to change notification settings - Fork 525
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
pt: explicitly set device #3307
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Force requires setting device to `env.DEVICE` or `cpu` explictly for functions like `torch.tensor`, `torch.zeros`, `torch.ones`, `torch.rand`, `torch.eye`, and dataloader. This ensures that no OPs that should be run on GPUs run on CPUs. The trick here is `torch.set_default_device("cuda:9999999")` in the tests, so errors will be thrown if no device is set. Tips: (1) Avoid `torch.zeros(...).to(device=...)`. This firstly initlizes memory on CPUs and copy it to GPUs. (2) Use `with torch.device(...)` for third-party modules. Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #3307 +/- ##
=======================================
Coverage 75.29% 75.29%
=======================================
Files 398 398
Lines 33684 33694 +10
Branches 1604 1604
=======================================
+ Hits 25361 25371 +10
Misses 7462 7462
Partials 861 861 ☔ View full report in Codecov by Sentry. |
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
wanghan-iapcm
approved these changes
Feb 20, 2024
iProzd
reviewed
Feb 20, 2024
iProzd
reviewed
Feb 20, 2024
anyangml
approved these changes
Feb 20, 2024
iProzd
approved these changes
Feb 20, 2024
github-merge-queue
bot
removed this pull request from the merge queue due to a conflict with the base branch
Feb 21, 2024
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Closed
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Forcely requires setting the device to
env.DEVICE
orcpu
explicitly for functions liketorch.tensor
,torch.zeros
,torch.ones
,torch.rand
,torch.eye
,LayerNorm
, and data loader.This ensures that no OP runs on the wrong device. The trick here is
torch.set_default_device("cuda:9999999")
in the tests, so errors will be thrown if the default device is used.Tips:
(1) Avoid
torch.zeros(...).to(device=...)
. This first initializes memory on CPUs and copies it to GPUs.(2) Use
with torch.device(...)
for a module that cannot set a device (i.e., the data loader).