Skip to content

Commit

Permalink
Merge pull request #1467 from IntelPython/fix-usm-ndarray-ctor-when-s…
Browse files Browse the repository at this point in the history
…hape-is-integral-numpy-scalar

Fix usm_ndarray ctor when shape is integral numpy scalar
  • Loading branch information
oleksandr-pavlyk authored Nov 8, 2023
2 parents dbab3fe + da59476 commit f686102
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
21 changes: 14 additions & 7 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,20 @@ cdef class usm_ndarray:
cdef bint is_fp16 = False

self._reset()
if (not isinstance(shape, (list, tuple))
and not hasattr(shape, 'tolist')):
try:
<Py_ssize_t> shape
shape = [shape, ]
except Exception:
raise TypeError("Argument shape must be a list or a tuple.")
if not isinstance(shape, (list, tuple)):
if hasattr(shape, 'tolist'):
fn = getattr(shape, 'tolist')
if callable(fn):
shape = shape.tolist()
if not isinstance(shape, (list, tuple)):
try:
<Py_ssize_t> shape
shape = [shape, ]
except Exception as e:
raise TypeError(
"Argument shape must a non-negative integer, "
"or a list/tuple of such integers."
) from e
nd = len(shape)
if dtype is None:
if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)):
Expand Down
1 change: 1 addition & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
(2, 5, 2),
(2, 2, 2, 2, 2, 2, 2, 2),
5,
np.int32(7),
],
)
@pytest.mark.parametrize("usm_type", ["shared", "host", "device"])
Expand Down
14 changes: 8 additions & 6 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(arg_dtype, q)

x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
x_shape = (24, 1024)
x_size = np.prod(x_shape)
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
idx = randrange(x.size)
idx_tup = np.unravel_index(idx, (24, 1025))
idx_tup = np.unravel_index(idx, x_shape)
x[idx] = 2

m = dpt.argmax(x)
Expand All @@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype):
m = dpt.argmax(y)
assert m == 2 * idx

x = dpt.reshape(x, (24, 1025))
x = dpt.reshape(x, x_shape)

x[idx_tup[0], :] = 3
m = dpt.argmax(x, axis=0)
Expand All @@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype):
m = dpt.argmax(x, axis=1)
assert dpt.all(m == idx)

x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q)
x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q)
idx = randrange(x.size)
idx_tup = np.unravel_index(idx, (24, 1025))
idx_tup = np.unravel_index(idx, x_shape)
x[idx] = 0

m = dpt.argmin(x)
assert m == idx

x = dpt.reshape(x, (24, 1025))
x = dpt.reshape(x, x_shape)

x[idx_tup[0], :] = -1
m = dpt.argmin(x, axis=0)
Expand Down

0 comments on commit f686102

Please sign in to comment.