
__cuda_array_interface__ conversion does not support readonly arrays #32868

Open
@hawkinsp

Description


PyTorch does not support importing readonly GPU arrays via __cuda_array_interface__ (spec: https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html). Its importer rejects them outright:

throw TypeError("the read only flag is not supported, should always be False");
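Per the spec, the `data` field is a `(pointer, read_only)` tuple, so a caller can check a producer's flag before handing the object to PyTorch. A minimal sketch; `is_readonly` and `FakeArray` are hypothetical names used only for illustration, and `FakeArray` simply mimics the dict JAX exports in the aside below:

```python
from typing import Any


def is_readonly(obj: Any) -> bool:
    """Return the read-only flag a producer exports via __cuda_array_interface__.

    Per the interface spec, 'data' is a 2-tuple (pointer, read_only).
    """
    iface = obj.__cuda_array_interface__
    _ptr, read_only = iface["data"]
    return bool(read_only)


class FakeArray:
    """Hypothetical stand-in mimicking the interface dict a JAX GPU array exports."""

    __cuda_array_interface__ = {
        "shape": (3,),
        "typestr": "<i4",
        "data": (140492215944704, True),
        "version": 2,
    }


print(is_readonly(FakeArray()))  # True: PyTorch's importer raises TypeError for this
```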

To Reproduce

If you have jax and jaxlib with GPU support built from head, the following reproduces the error:

In [1]: import torch, jax, jax.numpy as jnp

In [2]: x = jnp.array([1,2,3])

In [3]: y = torch.as_tensor(x)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-f43755a1deef> in <module>
----> 1 y = torch.tensor(x)

TypeError: the read only flag is not supported, should always be False

(Aside: this is not a PyTorch bug, but curiously, CuPy drops the readonly flag when it re-exports the array, so you can make the import "work" by laundering the array through CuPy:

In [1]: import torch, jax, jax.numpy as jnp, cupy

In [2]: x = jnp.array([1,2,3])

In [3]: y = torch.as_tensor(cupy.asarray(x), device="cuda")

In [4]: x.__cuda_array_interface__
Out[4]:
{'shape': (3,),
 'typestr': '<i4',
 'data': (140492215944704, True),
 'version': 2}

In [5]: y.__cuda_array_interface__
Out[5]:
{'typestr': '<i4',
 'shape': (3,),
 'strides': (4,),
 'data': (140492215944704, False),
 'version': 1}
)
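The CuPy round trip above works only because the readonly flag is cleared along the way. The same effect can be sketched without CuPy using a thin wrapper; `ReadonlyStripped` is a hypothetical helper, not part of any library. This is an untested illustration that assumes PyTorch keeps a reference to the wrapped object for the tensor's lifetime. Because JAX still treats the buffer as immutable, writing through the resulting tensor is undefined behavior:

```python
from typing import Any, Dict


class ReadonlyStripped:
    """Hypothetical wrapper that re-exports another object's CUDA array
    interface with the read-only flag cleared, as the CuPy round trip does.

    Caution: the producer still considers the buffer immutable, so any write
    through a tensor built from this wrapper is undefined behavior.
    """

    def __init__(self, source: Any) -> None:
        # Hold a reference so the underlying device buffer stays alive.
        self._source = source

    @property
    def __cuda_array_interface__(self) -> Dict[str, Any]:
        # Shallow-copy so the producer's own dict is never mutated.
        iface = dict(self._source.__cuda_array_interface__)
        ptr, _read_only = iface["data"]
        iface["data"] = (ptr, False)
        return iface
```

Usage would look like `torch.as_tensor(ReadonlyStripped(x), device="cuda")`; as with CuPy, this only hides the flag rather than honoring it.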

Expected behavior

PyTorch should accept arrays whose __cuda_array_interface__ sets the readonly flag, rather than raising a TypeError.
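One possible semantics, offered only as a sketch: accept readonly imports but warn, mirroring what `torch.from_numpy` already does for non-writeable NumPy arrays (it emits a `UserWarning` that writing to the tensor is undefined behavior). `check_cuda_array_readonly` is a hypothetical Python stand-in for the C++ check quoted above, not PyTorch code:

```python
import warnings
from typing import Any, Dict, Tuple


def check_cuda_array_readonly(iface: Dict[str, Any]) -> Tuple[int, bool]:
    """Hypothetical consumer-side check: warn on readonly data instead of raising.

    Returns (pointer, read_only) so the caller can proceed with the import.
    """
    ptr, read_only = iface["data"]
    if read_only:
        warnings.warn(
            "The given CUDA array is read-only; writing to the resulting "
            "tensor is undefined behavior.",
            UserWarning,
            stacklevel=2,
        )
    return ptr, bool(read_only)


jax_like = {"shape": (3,), "typestr": "<i4", "data": (140492215944704, True), "version": 2}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(check_cuda_array_readonly(jax_like))  # prints (140492215944704, True)
# `caught` now holds one UserWarning instead of a TypeError having been raised.
```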

cc @ngimel

Metadata


Assignees: No one assigned

Labels: module: cuda, module: internals, module: numba, triaged
