Description
PyTorch does not support importing read-only GPU arrays (those with the readonly flag set) via __cuda_array_interface__:
https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
The relevant code is in pytorch/torch/csrc/utils/tensor_numpy.cpp, line 292 (commit 2471ddc).
To Reproduce
If you have jax and jaxlib with GPU support built from HEAD, the following reproduces the error:
In [1]: import torch, jax, jax.numpy as jnp
In [2]: x = jnp.array([1,2,3])
In [3]: y = torch.as_tensor(x)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-f43755a1deef> in <module>
----> 1 y = torch.tensor(x)
TypeError: the read only flag is not supported, should always be False
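(Aside, and not a PyTorch bug: curiously, CuPy drops the readonly flag, so you can make the import "work" by laundering the array through CuPy. The data pointer is identical in both interface dicts below, so the tensor aliases JAX's buffer without a copy; only the readonly flag in 'data' differs: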
In [1]: import torch, jax, jax.numpy as jnp, cupy
In [2]: x = jnp.array([1,2,3])
In [3]: y = torch.as_tensor(cupy.asarray(x), device="cuda")
In [4]: x.__cuda_array_interface__
Out[4]:
{'shape': (3,),
'typestr': '<i4',
'data': (140492215944704, True),
'version': 2}
In [5]: y.__cuda_array_interface__
Out[5]:
{'typestr': '<i4',
'shape': (3,),
'strides': (4,),
'data': (140492215944704, False),
'version': 1}
)
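The effect of the CuPy round trip can be sketched without CuPy: re-export the producer's interface dict with the same pointer and the readonly flag forced to False. WritableAlias is a hypothetical wrapper written for illustration; it is not how CuPy implements this, only a model of the observable result shown above. The demo uses a SimpleNamespace stand-in carrying the JAX dict from the session, so it runs without a GPU.

```python
from types import SimpleNamespace


class WritableAlias:
    """Hypothetical zero-copy wrapper: re-exports another object's
    __cuda_array_interface__ with the readonly flag forced to False."""

    def __init__(self, source):
        # Hold a reference so the producer stays alive as long as this wrapper does.
        self._source = source
        # Shallow-copy the dict so the producer's own interface is left untouched.
        iface = dict(source.__cuda_array_interface__)
        ptr, _readonly = iface["data"]
        # Same device pointer, flag dropped: no device memory is copied.
        iface["data"] = (ptr, False)
        self.__cuda_array_interface__ = iface


# Stand-in for the JAX array, using the dict printed in the session above.
jax_like = SimpleNamespace(__cuda_array_interface__={
    "shape": (3,), "typestr": "<i4",
    "data": (140492215944704, True), "version": 2,
})
alias = WritableAlias(jax_like)
```

On a GPU machine, torch.as_tensor(WritableAlias(x), device="cuda") would presumably import the JAX array without the TypeError, with the same caveat as the CuPy route: JAX treats its buffers as immutable, so writing through the resulting tensor could silently change values JAX still holds.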
Expected behavior
PyTorch should accept CUDA arrays whose __cuda_array_interface__ has the readonly flag set (data[1] == True), instead of raising TypeError.
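One shape the requested behavior could take, sketched in Python rather than in tensor_numpy.cpp: accept the buffer and warn, the way PyTorch already warns (rather than errors) when handed a non-writable NumPy array. check_readonly, its policy argument, and the warning text are hypothetical; policy="raise" models the behavior reported above and policy="warn" models the requested one.

```python
import warnings

_POLICIES = ("raise", "warn")


def check_readonly(iface: dict, policy: str = "warn") -> int:
    """Hypothetical import-time check on a __cuda_array_interface__ dict.

    Returns the device pointer if the import may proceed.
    """
    if policy not in _POLICIES:
        raise ValueError(f"unknown policy: {policy!r}")
    ptr, readonly = iface["data"]
    if readonly:
        if policy == "raise":
            # Mirrors the error reported in this issue.
            raise TypeError(
                "the read only flag is not supported, should always be False"
            )
        # Requested behavior: import anyway, but flag that writes are unsafe.
        warnings.warn(
            "the given CUDA array is read-only; writing to the resulting "
            "tensor is undefined behavior",
            UserWarning,
            stacklevel=2,
        )
    return ptr
```

Warning instead of copying keeps the import zero-copy, which is the point of the interface; copying read-only buffers on import would be the more conservative alternative.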
cc @ngimel