support numel kernel for int8, update cuda_array_interface to v2, and add a unit test for cuda_array_interface
HydrogenSulfate committed Sep 13, 2024
1 parent 8e40e50 commit 88eddaa
Showing 6 changed files with 133 additions and 8 deletions.
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/numel_kernel.cc
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(numel,
CPU,
ALL_LAYOUT,
phi::NumelKernel,
int8_t,
uint8_t,
int16_t,
int,
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/numel_kernel.cu
@@ -22,6 +22,8 @@ PD_REGISTER_KERNEL(numel,
GPU,
ALL_LAYOUT,
phi::NumelKernel,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
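
With int8 now registered for the CPU and GPU numel kernels above, the size/numel op accepts int8 inputs on both backends. A minimal usage sketch, assuming a Paddle build that includes this change:

import paddle

# Cast to int8 explicitly; the registrations above let the numel kernel
# handle this dtype on CPU and, when available, GPU.
x = paddle.ones([11, 66]).astype("int8")
n = paddle.numel(x)  # expected: a 0-D integer Tensor holding the element count
assert n.item() == 11 * 66
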
34 changes: 26 additions & 8 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -1253,32 +1253,37 @@ def __cuda_array_interface__(self):
"""Array view description for cuda tensors.
See:
https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
CUDA Array Interface (Version 2)
https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
"""

# raise AttributeError for unsupported tensors, so that
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
if "gpu" not in str(self.place):
raise AttributeError(
f"Can't get __cuda_array_interface__ on non-CUDA tensor type: {self.type()} "
"Can't get __cuda_array_interface__ on non-CUDA tensor. "
"If CUDA data is required use tensor.cuda() to copy tensor to device memory."
)

if self.is_sparse():
raise AttributeError(
f"Can't get __cuda_array_interface__ on sparse type: {self.type()} "
"Can't get __cuda_array_interface__ on sparse tensor. "
"Use Tensor.to_dense() to convert to a dense tensor first."
)

# RuntimeError, matching tensor.__array__() behavior.
if not self.stop_gradient:
raise RuntimeError(
"Can't get __cuda_array_interface__ on Variable that stop_gradient=False. "
"If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
"Can't get __cuda_array_interface__ on Tensor that requires grad. "
"If gradients aren't required, use var.detach() to get Tensor that doesn't require grad."
)

# CUDA devices are little-endian and tensors are stored in native byte
# order. 1-byte entries are endian-agnostic.
typestr = {
paddle.complex64: "<c8",
paddle.complex128: "<c16",
paddle.bfloat16: "<f2",
paddle.float16: "<f2",
paddle.float32: "<f4",
paddle.float64: "<f8",
@@ -1287,20 +1292,33 @@ def __cuda_array_interface__(self):
paddle.int16: "<i2",
paddle.int32: "<i4",
paddle.int64: "<i8",
paddle.bool: "|b1",
# NOTE: Paddle does not support uint32, uint64, uint16 yet.
# paddle.uint16: "<u2",
# paddle.uint32: "<u4",
# paddle.uint64: "<u8",
}[self.dtype]

itemsize = self.element_size()

shape = tuple(self.shape)
strides = tuple(s * itemsize for s in self.strides)
data = (self.data_ptr(), False) # read-only is false
if self.is_contiguous():
# __cuda_array_interface__ v2 requires the strides to be omitted
# (either not set or set to None) for C-contiguous arrays.
strides = None
else:
# the number of bytes to skip to access the next element at each dimension.
strides = tuple(s * itemsize for s in self.strides)

data_ptr = self.data_ptr() if self.numel().item() > 0 else 0
data = (data_ptr, False) # read-only is false

return {
"typestr": typestr,
"shape": shape,
"strides": strides,
"data": data,
"version": 0,
"version": 2,
}

if not hasattr(core, "eager"):
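
With __cuda_array_interface__ exposed at version 2, libraries that consume the CUDA Array Interface (for example CuPy or Numba) can wrap a Paddle GPU tensor without copying. A minimal interop sketch, assuming CuPy is installed and Paddle is built with CUDA:

import paddle
import cupy as cp  # CuPy consumes __cuda_array_interface__ for zero-copy views

x = paddle.randn([3, 3]).cuda()
y = cp.asarray(x)  # shares the same device memory, no copy

# The raw dictionary can also be inspected directly.
desc = x.__cuda_array_interface__
print(desc["typestr"], desc["shape"], desc["version"])  # <f4 (3, 3) 2
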
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -77,6 +77,7 @@
'tolist',
'value',
'zero_',
"__cuda_array_interface__",
]
)
STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(
91 changes: 91 additions & 0 deletions test/legacy_test/test_eager_tensor.py
@@ -16,6 +16,7 @@
import unittest

import numpy as np
from utils import dygraph_guard

import paddle
import paddle.nn.functional as F
@@ -1187,6 +1188,96 @@ def test_print_tensor_dtype(self):

self.assertEqual(a_str, expected)

def test___cuda_array_interface__(self):
"""test Tensor.__cuda_array_interface__"""
with dygraph_guard():
# raise AttributeError for cpu tensor.
cpu_place = paddle.CPUPlace()
cpu_tensor = paddle.rand([3, 3]).to(device=cpu_place)
self.assertRaises(
AttributeError,
getattr,
cpu_tensor,
'__cuda_array_interface__',
)

if paddle.device.is_compiled_with_cuda():
gpu_place = paddle.CUDAPlace(0)
# raise AttributeError for sparse tensor.
sparse_tensor = (
paddle.rand([3, 3]).to(device=gpu_place).to_sparse_coo(2)
)
self.assertRaises(
AttributeError,
getattr,
sparse_tensor,
'__cuda_array_interface__',
)

# strides should be None if contiguous
tensor = paddle.randn([3, 3]).to(device=gpu_place)
interface = tensor.__cuda_array_interface__
assert interface["strides"] is None

# strides should be tuple of int if not contiguous
tensor = paddle.randn([10, 10]).to(device=gpu_place)
tensor = tensor[::2]
interface = tensor.__cuda_array_interface__
assert interface["strides"] == (80, 4)

# data_ptr should be 0 if tensor is 0-size
tensor = paddle.randn([0, 10]).to(device=gpu_place)
interface = tensor.__cuda_array_interface__
assert interface["data"][0] == 0

# raise AttributeError for tensor that requires grad.
tensor = paddle.randn([3, 3]).to(device=gpu_place)
tensor.stop_gradient = False
self.assertRaises(
RuntimeError,
getattr,
tensor,
'__cuda_array_interface__',
)

# check supports of dtypes
for dtype in [
paddle.complex64,
paddle.complex128,
paddle.bfloat16,
paddle.float16,
paddle.float32,
paddle.float64,
paddle.uint8,
paddle.int8,
paddle.int16,
paddle.int32,
paddle.int64,
paddle.bool,
]:
tensor = (
paddle.uniform([10, 10], min=-10.0, max=10.0)
.to(device=gpu_place)
.astype(dtype)
)
interface = tensor.__cuda_array_interface__
assert "typestr" in interface and isinstance(
interface["typestr"], str
)
assert "shape" in interface and isinstance(
interface["shape"], tuple
)
assert "strides" in interface and (
isinstance(interface["strides"], tuple)
or interface["strides"] is None
)
assert (
"data" in interface
and isinstance(interface["data"], tuple)
and len(interface["data"]) == 2
)
assert "version" in interface and interface["version"] == 2


class TestEagerTensorSetitem(unittest.TestCase):
def func_setUp(self):
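
The (80, 4) expectation in the non-contiguous strides check of the test above follows directly from the byte-stride computation in the patched method: slicing a contiguous [10, 10] float32 tensor with [::2] leaves element strides of (20, 1), and multiplying by the 4-byte item size gives (80, 4). A small sketch of that arithmetic with a hypothetical helper (not part of the test):

def expected_byte_strides(elem_strides, itemsize):
    # Mirrors the byte-stride computation used for non-contiguous tensors.
    return tuple(s * itemsize for s in elem_strides)

# [10, 10] float32 sliced with [::2]: element strides (20, 1), itemsize 4 bytes.
assert expected_byte_strides((20, 1), 4) == (80, 4)
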
12 changes: 12 additions & 0 deletions test/legacy_test/test_numel_op.py
@@ -71,6 +71,18 @@ def init(self):
self.shape = (0,)


class TestNumelOp1int8(TestNumelOp):
def init(self):
self.dtype = np.int8
self.shape = (11, 66)


class TestNumelOp2int8(TestNumelOp):
def init(self):
self.dtype = np.int8
self.shape = (0,)


class TestNumelOpComplex(TestNumelOp):
def setUp(self):
self.op_type = "size"
