From e0b79ff0398ca2fbdb0e7a9c738edce01a10653a Mon Sep 17 00:00:00 2001 From: Natalia Polina Date: Mon, 6 Feb 2023 05:23:40 -0600 Subject: [PATCH 1/2] Added missing usm_type to tril() and triu() functions. --- dpctl/tensor/_ctors.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index aae31fddbf..2de9c979bb 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -1247,7 +1247,11 @@ def tril(X, k=0): if k >= shape[nd - 1] - 1: res = dpt.empty( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=X, dst=res, sycl_queue=X.sycl_queue @@ -1255,11 +1259,19 @@ def tril(X, k=0): hev.wait() elif k < -shape[nd - 2]: res = dpt.zeros( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) else: res = dpt.empty( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) hev, _ = ti._tril(src=X, dst=res, k=k, sycl_queue=X.sycl_queue) hev.wait() @@ -1290,11 +1302,19 @@ def triu(X, k=0): if k > shape[nd - 1]: res = dpt.zeros( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) elif k <= -shape[nd - 2] + 1: res = dpt.empty( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) hev, _ = ti._copy_usm_ndarray_into_usm_ndarray( src=X, dst=res, sycl_queue=X.sycl_queue @@ -1302,7 +1322,11 @@ def triu(X, k=0): hev.wait() else: res = dpt.empty( - X.shape, dtype=X.dtype, order=order, sycl_queue=X.sycl_queue + X.shape, + dtype=X.dtype, + order=order, + usm_type=X.usm_type, + sycl_queue=X.sycl_queue, ) hev, _ = ti._triu(src=X, dst=res, k=k, sycl_queue=X.sycl_queue) hev.wait() From ec8509f8271de9de2584a49a81e5f1939c3a094e Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 6 Feb 2023 06:17:01 -0600 Subject: [PATCH 2/2] Added tests for tril/triu usm_type, queue --- dpctl/tests/test_usm_ndarray_ctor.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index a60a085557..4534959a52 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -1503,6 +1503,28 @@ def test_triu(dtype): assert np.array_equal(Ynp, dpt.asnumpy(Y)) +@pytest.mark.parametrize("tri_fn", [dpt.tril, dpt.triu]) +@pytest.mark.parametrize("usm_type", ["device", "shared", "host"]) +def test_tri_usm_type(tri_fn, usm_type): + q = get_queue_or_skip() + dtype = dpt.uint16 + + shape = (2, 3, 4, 5, 5) + size = np.prod(shape) + X = dpt.reshape( + dpt.arange(size, dtype=dtype, usm_type=usm_type, sycl_queue=q), shape + ) + Y = tri_fn(X) # main execution branch + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == q + Y = tri_fn(X, k=-6) # special case of Y == X + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == q + Y = tri_fn(X, k=6) # special case of Y == 0 + assert Y.usm_type == X.usm_type + assert Y.sycl_queue == q + + def test_tril_slice(): q = get_queue_or_skip()