__cuda_array_interface__ conversion does not support readonly arrays #32868

Open

hawkinsp opened this issue Jan 31, 2020 · 1 comment
Labels
module: cuda, module: internals, module: numba, triaged

Comments

hawkinsp commented Jan 31, 2020

PyTorch does not support importing readonly GPU arrays via __cuda_array_interface__:
https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html

throw TypeError("the read only flag is not supported, should always be False");

To Reproduce

If you have a copy of jax and jaxlib with GPU support built from head, the following will reproduce:

In [1]: import torch, jax, jax.numpy as jnp

In [2]: x = jnp.array([1,2,3])

In [3]: y = torch.as_tensor(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-f43755a1deef> in <module>
----> 1 y = torch.as_tensor(x)

TypeError: the read only flag is not supported, should always be False

(Aside: this is not a PyTorch bug, but curiously CuPy drops the readonly flag, so you can make the import "work" by laundering the array through CuPy:

In [1]: import torch, jax, jax.numpy as jnp, cupy

In [2]: x = jnp.array([1,2,3])

In [3]: y = torch.as_tensor(cupy.asarray(x), device="cuda")

In [4]: x.__cuda_array_interface__
Out[4]:
{'shape': (3,),
 'typestr': '<i4',
 'data': (140492215944704, True),
 'version': 2}

In [5]: y.__cuda_array_interface__
Out[5]:
{'typestr': '<i4',
 'shape': (3,),
 'strides': (4,),
 'data': (140492215944704, False),
 'version': 1}
)
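
For anyone hitting this before PyTorch adds support: a minimal sketch of the same trick without the CuPy detour is to wrap the producer in a tiny shim that re-exports __cuda_array_interface__ with the read-only flag cleared. The name _WritableView below is hypothetical, and this is only safe if you never write through the resulting tensor:

import torch

class _WritableView:
    # Hypothetical shim: expose the source array's __cuda_array_interface__
    # with the read-only flag forced to False so torch.as_tensor accepts it.
    def __init__(self, arr):
        self._arr = arr  # hold a reference so the device memory stays alive

    @property
    def __cuda_array_interface__(self):
        iface = dict(self._arr.__cuda_array_interface__)
        ptr, _readonly = iface["data"]
        iface["data"] = (ptr, False)  # clear the read-only flag
        return iface

# y = torch.as_tensor(_WritableView(x), device="cuda")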

Expected behavior

PyTorch should support the readonly flag.
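
For comparison, the CPU path already tolerates read-only inputs: torch.as_tensor on a non-writable NumPy array shares the memory and, as far as I can tell in recent PyTorch versions, warns instead of raising. Something analogous seems reasonable for the CUDA interface. A sketch, under that assumption:

import numpy as np
import torch

a = np.arange(3)
a.setflags(write=False)   # make the CPU array read-only
t = torch.as_tensor(a)    # warns that the array is not writable, but succeeds
print(t)                  # tensor([0, 1, 2])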

cc @ngimel

ezyang added the module: cuda, module: internals, module: numba, module: numpy, and triaged labels on Feb 3, 2020
michaelklachko commented

Sorry for the off-topic question, but I'm curious about the y = torch.as_tensor(cupy.asarray(x), device="cuda") line. I was looking for a way to pass torch CUDA tensors to CuPy for some ops not available in PyTorch and then pass them back to PyTorch, all while keeping the data on the same GPU. I thought the best way to do this was to use DLPack (e.g. see this CuPy page), but your example seems to indicate that DLPack is unnecessary. Is there any reason to use DLPack in this scenario?
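
For what it's worth, __cuda_array_interface__ and DLPack both give a zero-copy handoff on the same GPU; DLPack is just the more explicit handshake and the route documented in the CuPy interoperability guide. A sketch of the usual round trip, assuming a CuPy version that provides cupy.from_dlpack (older releases spell it cupy.fromDlpack):

import cupy
import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

t = torch.arange(3, device="cuda")

# PyTorch -> CuPy: both arrays view the same device memory
c = cupy.from_dlpack(to_dlpack(t))

# run CuPy-only ops here, e.g.
c = cupy.sqrt(c.astype(cupy.float32))

# CuPy -> PyTorch, again without a copy
t2 = from_dlpack(c.toDlpack())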

mruberry removed the module: numpy label on May 11, 2020