Don't force .cpu() on all PyTorch outputs #1052
I would really like to work on this if possible; it has burned me a few times.
Of course :)
@ricardoV94, I'm wondering if this overlaps with the issue I found when messing around with PyMC + py[torch|tensor]: #1065. What I'm asking is whether the linker should be smart enough to know when to do result.detach().numpy(). Then the issue with PyMC should in theory be solved.
The problem is that this fails when the data is on the GPU. Is there a cheap way to know when it is and when it's not? Just wrap it in a try/except?
Yeah.
Wanna try that? It's still suboptimal to always force a transfer, but probably fine for rough use of the backend. We may allow user control through custom linker settings in the future.
Should we combine this with the suggestion you made earlier as well?
Let's skip the idea of handling the updates for now and force everything to be NumPy once it's out. Otherwise you'll run into the same sort of problems you saw in your PyMC tests.
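A minimal sketch of the try/except idea from this thread, assuming only that a torch tensor exposes detach(), numpy() and cpu(). The helper name to_numpy is hypothetical and not part of PyTensor. It is duck-typed (it never imports torch), so non-tensor outputs such as NumPy or JAX arrays simply fall through to np.asarray. Catching both TypeError and RuntimeError is a cautious assumption about what .numpy() raises for a tensor on an accelerator.

```python
import numpy as np


def to_numpy(value):
    """Convert a backend output to a NumPy array (hypothetical helper).

    Anything exposing ``detach`` is treated as a torch.Tensor; everything
    else (NumPy arrays, JAX arrays, Python scalars, ...) goes to np.asarray.
    """
    if not hasattr(value, "detach"):
        return np.asarray(value)
    # Drop autograd history; .numpy() refuses tensors that require grad.
    tensor = value.detach()
    try:
        # For CPU tensors this shares memory with the tensor (no copy).
        return tensor.numpy()
    except (TypeError, RuntimeError):
        # .numpy() refuses tensors on an accelerator (e.g. CUDA), so only
        # in that case pay for the device-to-host copy.
        return tensor.cpu().numpy()
```

If an explicit check is preferred over try/except, a torch tensor's device.type (for example "cpu" or "cuda") is a cheap attribute lookup. Also note that Tensor.cpu() returns the original object when it is already in CPU memory, so an unconditional .cpu() only costs a copy for data that actually lives on an accelerator.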
This whole thing (i.e., calling out.cpu()) is suboptimal. I think we don't need it for JAX (which returns JAX arrays, not NumPy arrays) because np.asarray works with those, and I guess it doesn't work for torch tensors.
pytensor/pytensor/link/pytorch/linker.py, line 16 in 7b13a95
This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.
Perhaps we should expand the TorchLinker a bit so that it performs the updates itself, and only force conversion in that case. This is already supported by Function:
pytensor/pytensor/compile/function/types.py, lines 1009 to 1017 in 7b13a95
Originally posted by @ricardoV94 in #1032 (comment)
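A rough sketch of the proposal in the quoted comment: convert only the values that will be written back into shared variables, and hand every other output back in the backend's native type. The function name finalize_outputs, the "updates come last" layout and the to_numpy hook are all hypothetical assumptions for illustration; PyTensor's actual Function update handling may be organized differently.

```python
import numpy as np


def finalize_outputs(outputs, n_updates, to_numpy=np.asarray):
    """Split raw backend outputs into user outputs and shared-variable updates.

    Hypothetical layout (an assumption, not PyTensor's API): the last
    ``n_updates`` entries of ``outputs`` are new values for shared variables.
    """
    if not 0 <= n_updates <= len(outputs):
        raise ValueError("n_updates must be between 0 and len(outputs)")
    n_user = len(outputs) - n_updates
    # User-facing outputs keep the backend's native type (e.g. torch tensors
    # that may still live on the GPU); no device transfer happens here.
    user_outputs = list(outputs[:n_user])
    # Shared-variable updates may be read by functions compiled with a
    # different backend, so only these are converted to a common type.
    updates = [to_numpy(value) for value in outputs[n_user:]]
    return user_outputs, updates
```

The design choice here is that the cost of a device-to-host copy is paid only for values that must be shareable across backends, while ordinary outputs stay on whatever device produced them.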