
Don't force .cpu() on all PyTorch outputs #1052

Open
ricardoV94 opened this issue Oct 29, 2024 · 8 comments
Labels: backend compatibility, enhancement (New feature or request), torch (PyTorch backend)

Comments

@ricardoV94 (Member) commented Oct 29, 2024

This whole thing (i.e., calling out.cpu()) is suboptimal. I think we don't need it for JAX (which returns JAX arrays, not numpy arrays), because np.asarray works with them, and I guess it doesn't work for torch tensors.

This should only be needed for updated shared variables, where we have to convert to a common type because they could be used in multiple functions with distinct backends.

Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by Function:

if getattr(self.vm, "need_update_inputs", True):
    # Update the inputs that have an update function
    for input, storage in reversed(
        list(zip(self.maker.expanded_inputs, input_storage))
    ):
        if input.update is not None:
            storage.data = outputs.pop()
else:
    outputs = outputs[: self.n_returned_outputs]

Originally posted by @ricardoV94 in #1032 (comment)
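For context, a minimal sketch (not from the codebase) of why the conversion differs between backends; it assumes torch is installed and only comments out the GPU case, which requires a CUDA device:

import numpy as np
import torch

t = torch.ones(3)
np.asarray(t)               # fine: CPU tensors expose __array__

t_grad = torch.ones(3, requires_grad=True)
# np.asarray(t_grad)        # raises: numpy() can't be called on a tensor that requires grad
np.asarray(t_grad.detach())  # explicit detach is needed

# For tensors on an accelerator, numpy conversion raises outright,
# so an explicit transfer is required:
# t_gpu = torch.ones(3, device="cuda")
# np.asarray(t_gpu.cpu())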

@Ch0ronomato (Contributor)

I would really like to work on this if possible; it has burned me a few times.

@ricardoV94 (Member, Author)

> I would really like to work on this if possible; it has burned me a few times.

Of course :)

@Ch0ronomato (Contributor) commented Nov 4, 2024

@ricardoV94: I'm wondering if this overlaps with the issue I found when messing around with pymc + py[torch|tensor]: #1065. What I'm wondering is whether the linker should be smart enough to know when to do

result.detach().numpy()

Then the issue with pymc should in theory be solved.

@ricardoV94 (Member, Author)

The problem is that this fails when the data is on the GPU. Is there a cheap way to know when it is and when it's not? Just wrap it in a try/except?

@Ch0ronomato (Contributor)

Yeah, x.device gives you the current location of the tensor. As long as we check for cpu it's fairly straightforward (GPU device names vary).
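A minimal sketch of that check (the to_numpy helper name and its placement are hypothetical, just to illustrate the idea being discussed):

import numpy as np
import torch

def to_numpy(out):
    # Hypothetical helper: convert a PyTorch output to numpy,
    # transferring off the accelerator only when needed.
    if isinstance(out, torch.Tensor):
        if out.device.type != "cpu":
            out = out.cpu()          # explicit transfer for cuda/mps/... tensors
        return out.detach().numpy()  # detach in case the tensor tracks gradients
    return np.asarray(out)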

@ricardoV94 (Member, Author)

> Yeah, x.device gives you the current location of the tensor. As long as we check for cpu it's fairly straightforward (GPU device names vary).

Wanna try that? It's still suboptimal to always force the transfer, but probably fine for a rough use of the backend. We may allow user control with custom linker settings in the future.

@Ch0ronomato (Contributor)

Would we combine this with the suggestion you had earlier as well?

> Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by Function.

@ricardoV94 (Member, Author)

> Would we combine this with the suggestion you had earlier as well?
> Perhaps we should expand a bit on the TorchLinker to perform the updates itself, and only force conversion when that's the case. This is already supported by Function.

Let's skip the updates idea for now and force everything to numpy once it's out. Otherwise you'll have the same sort of problems you saw in your PyMC tests.
