Configure GPU index for 'astra_cuda', select GPU currently used by PyTorch in OperatorModule #1546
base: master

Conversation
Hey @jleuschn, thanks a lot for your effort! From a coding point of view, it looks very good. However, I'm afraid that this solution is too hacky and not future-proof. That's mostly due to circumstances and not your fault. Let me give my reasoning here.
Okay, this has become way longer than I expected, and I realize it will not be a trivial change. I will look into it tonight and make a suggestion. But I have no way to test with multiple GPUs, so you would have to help me out @jleuschn. Does that sound okay?
Thanks, @kohr-h, for checking the request and pointing out the issues above! While parameters and buffers are broadcast to the GPUs in PyTorch's …

Yes, I can help out with testing on multiple GPUs!
Good point about the replication thing! Hm, so the first call to …

Regarding the larger question of whether it's worth the effort: currently I have my doubts. The whole thing is quite inefficient anyway, since each ray transform does the whole round trip CPU->GPU->CPU no matter what, and if it's wrapped into an …
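For illustration, here is a rough sketch of the round trip being described. This is a hypothetical helper, not ODL's actual wrapper code; the function name and body are assumptions made for the sake of the example:

```python
import numpy as np
import torch

def apply_ray_trafo(ray_trafo, x):
    """Rough sketch of the data flow described above (hypothetical helper,
    not ODL's actual wrapper code).

    Even if ``x`` already lives on a GPU, the operator call goes through
    host memory: GPU tensor -> CPU numpy array -> ASTRA (which uploads to a
    GPU again, computes, and downloads) -> CPU numpy array -> GPU tensor.
    """
    x_cpu = x.detach().cpu().numpy()              # GPU -> CPU
    y_cpu = np.asarray(ray_trafo(x_cpu))          # CPU -> GPU -> CPU inside ASTRA
    return torch.from_numpy(y_cpu).to(x.device)   # CPU -> GPU again
```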
Yes, I ran some speed tests. It seems that it only makes a difference if the GPUs are already heavily used by the rest of the network. To be honest, I don't fully understand why. Considering the mentioned chain, maybe the chains are not in sync between the different GPUs, so some ray trafo runs in parallel to another layer?
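A minimal timing harness of the kind one might use for such a speed test could look like the sketch below; this is an assumption, not the actual benchmark referred to above:

```python
import time
import torch

def time_forward(model, x, n_runs=20):
    """Hypothetical timing helper (not the actual benchmark used above)."""
    # Warm-up run so lazy initialization doesn't skew the measurement.
    model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs
```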
Very good, thanks for doing the speed tests! Indeed, the gain is not nothing, but certainly not what you would hope for when throwing N times the compute power at the problem. So I agree, for now it's not necessary to invest time, but it's good to know about this limitation and that we need to think about solutions at some point.
This pull request implements two new features:

- a `'gpu_index'` option in `RayTransformBase` (for the `'astra_cuda'` implementation)
- selection of the GPU currently used by PyTorch in `OperatorModule`, by setting `'gpu_index'` on the wrapped operator

The second feature feels a little hacky, since it assumes a special role of the `'gpu_index'` property, if it exists, for any `Operator` instance that is wrapped by the `OperatorModule`. However, this seems to me to be the most non-invasive way to implement this behaviour, since otherwise `RayTransformBase` would probably have to know about torch, e.g. by offering `gpu_index = 'torch_current'`.
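To make the second feature concrete, here is a rough sketch of the idea. The class name `OperatorModuleSketch` and the simplified `forward` body are illustrative assumptions, not the PR's actual diff; the real `OperatorModule` also handles batching and autograd:

```python
import numpy as np
import torch

class OperatorModuleSketch(torch.nn.Module):
    """Rough sketch of the idea behind the second feature (hypothetical,
    not the PR's actual implementation).

    If the wrapped ODL operator exposes a ``gpu_index`` attribute, set it to
    the CUDA device index of the incoming tensor before applying the
    operator, so that ASTRA runs on the GPU PyTorch is currently using.
    """

    def __init__(self, operator):
        super().__init__()
        self.operator = operator

    def forward(self, x):
        if x.is_cuda and hasattr(self.operator, 'gpu_index'):
            self.operator.gpu_index = x.device.index
        # Simplified: apply the operator on a CPU numpy array and move the
        # result back to the input tensor's device.
        y = np.asarray(self.operator(x.detach().cpu().numpy()))
        return torch.as_tensor(y, device=x.device)
```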