From 54f3202fa2f842b95a221c757bcbc7e9f677a50c Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 6 Feb 2024 11:03:42 -0600 Subject: [PATCH 1/5] use out keyword for result --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 38 ++++++++++++++++++--- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 65d97befa98..ac181d80260 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -219,6 +219,17 @@ def dpnp_dot(a, b, /, out=None): usm_type=res_usm_type, sycl_queue=exec_q, ) + out_is_used = False + if out is not None: + dpnp.check_supported_arrays_type(out) + if ( + out.dtype == dot_dtype + and out.shape == () + and out.usm_type == res_usm_type + and out.sycl_queue == exec_q + ): + result = out + out_is_used = True # input arrays should have the proper data type dep_events_list = [] @@ -253,8 +264,11 @@ def dpnp_dot(a, b, /, out=None): if dot_dtype != res_dtype: result = result.astype(res_dtype, copy=False) - # NumPy does not allow casting even if it is safe - return dpnp.get_result_array(result, out, casting="no") + if out_is_used: + return out + else: + # NumPy does not allow casting even if it is safe + return dpnp.get_result_array(result, out, casting="no") def dpnp_matmul( @@ -361,13 +375,26 @@ def dpnp_matmul( x2_shape = x2.shape res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) - # calculate results + # create result array result = dpnp.empty( res_shape, dtype=gemm_dtype, usm_type=res_usm_type, sycl_queue=exec_q, ) + out_is_used = False + if out is not None: + dpnp.check_supported_arrays_type(out) + if ( + out.dtype == gemm_dtype + and out.shape == res_shape + and out.usm_type == res_usm_type + and out.sycl_queue == exec_q + ): + result = out + out_is_used = True + + # calculate result if result.size == 0: pass elif x1.size == 0 or x2.size == 0: @@ -446,4 +473,7 @@ def dpnp_matmul( else: return result else: - return dpnp.get_result_array(result, out, casting=casting) + if out_is_used: + return out + else: + return dpnp.get_result_array(result, out, casting=casting) From bdc118bb21116c297e2a41c33779eb5a45fcfb6e Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 6 Feb 2024 14:02:27 -0600 Subject: [PATCH 2/5] fix strided or overlapping out --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index ac181d80260..9d5fa309fd1 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -222,11 +222,17 @@ def dpnp_dot(a, b, /, out=None): out_is_used = False if out is not None: dpnp.check_supported_arrays_type(out) + a_usm = dpnp.get_usm_ndarray(a) + b_usm = dpnp.get_usm_ndarray(b) + out_usm = dpnp.get_usm_ndarray(out) + if ( out.dtype == dot_dtype and out.shape == () and out.usm_type == res_usm_type and out.sycl_queue == exec_q + and not ti._array_overlap(a_usm, out_usm) + and not ti._array_overlap(b_usm, out_usm) ): result = out out_is_used = True @@ -382,14 +388,22 @@ def dpnp_matmul( usm_type=res_usm_type, sycl_queue=exec_q, ) + out_is_used = False if out is not None: dpnp.check_supported_arrays_type(out) + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + out_usm = dpnp.get_usm_ndarray(out) + if ( out.dtype == gemm_dtype and out.shape == res_shape and out.usm_type == res_usm_type and out.sycl_queue == exec_q + and (out.flags.c_contiguous or out.flags.f_contiguous) + and not ti._array_overlap(x1_usm, out_usm) + and not ti._array_overlap(x2_usm, out_usm) ): result = out out_is_used = True From 5aaad71bf3a47f21e88173bdb741aab7940dff37 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 6 Feb 2024 17:36:14 -0600 Subject: [PATCH 3/5] address comments --- dpnp/dpnp_iface.py | 21 ++-- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 102 ++++++++------------ 2 files changed, 54 insertions(+), 69 deletions(-) diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index 9aee27b73bc..d0ffc75242b 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -484,17 +484,20 @@ def get_result_array(a, out=None, casting="safe"): if out is None: return a else: - dpnp.check_supported_arrays_type(out) - if out.shape != a.shape: - raise ValueError( - f"Output array of shape {a.shape} is needed, got {out.shape}." - ) - elif isinstance(out, dpt.usm_ndarray): - out = dpnp_array._create_from_usm_ndarray(out) + if a is out: + return out + else: + dpnp.check_supported_arrays_type(out) + if out.shape != a.shape: + raise ValueError( + f"Output array of shape {a.shape} is needed, got {out.shape}." + ) + elif isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) - dpnp.copyto(out, a, casting=casting) + dpnp.copyto(out, a, casting=casting) - return out + return out def get_usm_ndarray(a): diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 9d5fa309fd1..427b4de0aef 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -36,6 +36,42 @@ __all__ = ["dpnp_dot", "dpnp_matmul"] +def _create_resul_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): + """ + Create the result array. + + If `out` is not ``None`` and its features match the specified `shape`, `dtype, + `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and + does not have any memory overlap with `x1` and `x2`, `out` itself is returned. + If these conditions are not statisfied, an empty array is returned with the + specified `shape`, `dtype, `usm_type`, and `sycl_queue`. + """ + + if out is not None: + dpnp.check_supported_arrays_type(out) + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + out_usm = dpnp.get_usm_ndarray(out) + + if ( + out.dtype == dtype + and out.shape == shape + and out.usm_type == usm_type + and out.sycl_queue == sycl_queue + and (out.flags.c_contiguous or out.flags.f_contiguous) + and not ti._array_overlap(x1_usm, out_usm) + and not ti._array_overlap(x2_usm, out_usm) + ): + return out + + return dpnp.empty( + shape, + dtype=dtype, + usm_type=usm_type, + sycl_queue=sycl_queue, + ) + + def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): """ Creating a copy of input array if needed. @@ -212,31 +248,7 @@ def dpnp_dot(a, b, /, out=None): a, b, dtype=None, casting="no", sycl_queue=exec_q ) - # create result array - result = dpnp.empty( - (), - dtype=dot_dtype, - usm_type=res_usm_type, - sycl_queue=exec_q, - ) - out_is_used = False - if out is not None: - dpnp.check_supported_arrays_type(out) - a_usm = dpnp.get_usm_ndarray(a) - b_usm = dpnp.get_usm_ndarray(b) - out_usm = dpnp.get_usm_ndarray(out) - - if ( - out.dtype == dot_dtype - and out.shape == () - and out.usm_type == res_usm_type - and out.sycl_queue == exec_q - and not ti._array_overlap(a_usm, out_usm) - and not ti._array_overlap(b_usm, out_usm) - ): - result = out - out_is_used = True - + result = _create_resul_array(a, b, out, (), dot_dtype, res_usm_type, exec_q) # input arrays should have the proper data type dep_events_list = [] host_tasks_list = [] @@ -270,11 +282,8 @@ def dpnp_dot(a, b, /, out=None): if dot_dtype != res_dtype: result = result.astype(res_dtype, copy=False) - if out_is_used: - return out - else: - # NumPy does not allow casting even if it is safe - return dpnp.get_result_array(result, out, casting="no") + # NumPy does not allow casting even if it is safe + return dpnp.get_result_array(result, out, casting="no") def dpnp_matmul( @@ -381,33 +390,9 @@ def dpnp_matmul( x2_shape = x2.shape res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) - # create result array - result = dpnp.empty( - res_shape, - dtype=gemm_dtype, - usm_type=res_usm_type, - sycl_queue=exec_q, + result = _create_resul_array( + x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q ) - - out_is_used = False - if out is not None: - dpnp.check_supported_arrays_type(out) - x1_usm = dpnp.get_usm_ndarray(x1) - x2_usm = dpnp.get_usm_ndarray(x2) - out_usm = dpnp.get_usm_ndarray(out) - - if ( - out.dtype == gemm_dtype - and out.shape == res_shape - and out.usm_type == res_usm_type - and out.sycl_queue == exec_q - and (out.flags.c_contiguous or out.flags.f_contiguous) - and not ti._array_overlap(x1_usm, out_usm) - and not ti._array_overlap(x2_usm, out_usm) - ): - result = out - out_is_used = True - # calculate result if result.size == 0: pass @@ -487,7 +472,4 @@ def dpnp_matmul( else: return result else: - if out_is_used: - return out - else: - return dpnp.get_result_array(result, out, casting=casting) + return dpnp.get_result_array(result, out, casting=casting) From f8902317e0912fcf0b3e7c65231c4378818e9140 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Tue, 6 Feb 2024 20:21:08 -0600 Subject: [PATCH 4/5] fix typo --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 427b4de0aef..44666a8de95 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -36,7 +36,7 @@ __all__ = ["dpnp_dot", "dpnp_matmul"] -def _create_resul_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): +def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): """ Create the result array. @@ -248,7 +248,9 @@ def dpnp_dot(a, b, /, out=None): a, b, dtype=None, casting="no", sycl_queue=exec_q ) - result = _create_resul_array(a, b, out, (), dot_dtype, res_usm_type, exec_q) + result = _create_result_array( + a, b, out, (), dot_dtype, res_usm_type, exec_q + ) # input arrays should have the proper data type dep_events_list = [] host_tasks_list = [] @@ -390,7 +392,7 @@ def dpnp_matmul( x2_shape = x2.shape res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) - result = _create_resul_array( + result = _create_result_array( x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q ) # calculate result From d6bfa622388a2a17a2ae32d30bcc541f356a8b91 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 7 Feb 2024 09:00:40 -0600 Subject: [PATCH 5/5] remove additional check --- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 44666a8de95..c953a49ceeb 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -48,7 +48,6 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): """ if out is not None: - dpnp.check_supported_arrays_type(out) x1_usm = dpnp.get_usm_ndarray(x1) x2_usm = dpnp.get_usm_ndarray(x2) out_usm = dpnp.get_usm_ndarray(out)