From 7e790830cd04bdf7c9503d6861a5857d9e626650 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sun, 5 Nov 2023 16:49:46 -0600 Subject: [PATCH 1/4] Enable use of np.int64 to specify shape of usm_ndarray --- dpctl/tensor/_usmarray.pyx | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index ba18600135..94c3dc7d7c 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -182,13 +182,19 @@ cdef class usm_ndarray: cdef bint is_fp16 = False self._reset() - if (not isinstance(shape, (list, tuple)) - and not hasattr(shape, 'tolist')): - try: - shape - shape = [shape, ] - except Exception: - raise TypeError("Argument shape must be a list or a tuple.") + if not isinstance(shape, (list, tuple)): + if hasattr(shape, 'tolist'): + fn = getattr(shape, 'tolist') + if callable(fn): + shape = shape.tolist() + if not isinstance(shape, (list, tuple)): + try: + shape + shape = [shape, ] + except Exception: + raise TypeError( + "Argument shape must be a list or a tuple." + ) nd = len(shape) if dtype is None: if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)): From 2bc793923ddaaadc6bc8e05a72710943137a8432 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sun, 5 Nov 2023 16:50:33 -0600 Subject: [PATCH 2/4] Add a test for shape being np.int64 scalar --- dpctl/tests/test_usm_ndarray_ctor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index 72f5aabebb..095bbc5638 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -39,6 +39,7 @@ (2, 5, 2), (2, 2, 2, 2, 2, 2, 2, 2), 5, + np.int32(7), ], ) @pytest.mark.parametrize("usm_type", ["shared", "host", "device"]) From aadb6b4aea27f3d2cb0144520076336af61b3cb2 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Sun, 5 Nov 2023 16:50:52 -0600 Subject: [PATCH 3/4] Eliminated multiple uses of same literal constants in test_search_reduction_kernels --- dpctl/tests/test_usm_ndarray_reductions.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_reductions.py b/dpctl/tests/test_usm_ndarray_reductions.py index cbfd6baec6..0969822e6d 100644 --- a/dpctl/tests/test_usm_ndarray_reductions.py +++ b/dpctl/tests/test_usm_ndarray_reductions.py @@ -175,9 +175,11 @@ def test_search_reduction_kernels(arg_dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(arg_dtype, q) - x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + x_shape = (24, 1024) + x_size = np.prod(x_shape) + x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q) idx = randrange(x.size) - idx_tup = np.unravel_index(idx, (24, 1025)) + idx_tup = np.unravel_index(idx, x_shape) x[idx] = 2 m = dpt.argmax(x) @@ -194,7 +196,7 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(y) assert m == 2 * idx - x = dpt.reshape(x, (24, 1025)) + x = dpt.reshape(x, x_shape) x[idx_tup[0], :] = 3 m = dpt.argmax(x, axis=0) @@ -209,15 +211,15 @@ def test_search_reduction_kernels(arg_dtype): m = dpt.argmax(x, axis=1) assert dpt.all(m == idx) - x = dpt.ones((24 * 1025), dtype=arg_dtype, sycl_queue=q) + x = dpt.ones(x_size, dtype=arg_dtype, sycl_queue=q) idx = randrange(x.size) - idx_tup = np.unravel_index(idx, (24, 1025)) + idx_tup = np.unravel_index(idx, x_shape) x[idx] = 0 m = dpt.argmin(x) assert m == idx - x = dpt.reshape(x, (24, 1025)) + x = dpt.reshape(x, x_shape) x[idx_tup[0], :] = -1 m = dpt.argmin(x, axis=0) From da594763acb042330e723d160258dabc570ff964 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Tue, 7 Nov 2023 15:32:57 -0600 Subject: [PATCH 4/4] Changed TypeError wording per PR feedback --- dpctl/tensor/_usmarray.pyx | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 94c3dc7d7c..5b394d971b 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -191,10 +191,11 @@ cdef class usm_ndarray: try: shape shape = [shape, ] - except Exception: + except Exception as e: raise TypeError( - "Argument shape must be a list or a tuple." - ) + "Argument shape must a non-negative integer, " + "or a list/tuple of such integers." + ) from e nd = len(shape) if dtype is None: if isinstance(buffer, (dpmem._memory._Memory, usm_ndarray)):