diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index e37c2e090a6..d8838e67c8d 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -495,17 +495,20 @@ def get_result_array(a, out=None, casting="safe"): if out is None: return a else: - dpnp.check_supported_arrays_type(out) - if out.shape != a.shape: - raise ValueError( - f"Output array of shape {a.shape} is needed, got {out.shape}." - ) - elif isinstance(out, dpt.usm_ndarray): - out = dpnp_array._create_from_usm_ndarray(out) + if a is out: + return out + else: + dpnp.check_supported_arrays_type(out) + if out.shape != a.shape: + raise ValueError( + f"Output array of shape {a.shape} is needed, got {out.shape}." + ) + elif isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) - dpnp.copyto(out, a, casting=casting) + dpnp.copyto(out, a, casting=casting) - return out + return out def get_usm_ndarray(a): diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index bf1a3417704..3c36eda042d 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -36,6 +36,41 @@ __all__ = ["dpnp_dot", "dpnp_matmul"] +def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): + """ + Create the result array. + + If `out` is not ``None`` and its features match the specified `shape`, `dtype, + `usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and + does not have any memory overlap with `x1` and `x2`, `out` itself is returned. + If these conditions are not statisfied, an empty array is returned with the + specified `shape`, `dtype, `usm_type`, and `sycl_queue`. + """ + + if out is not None: + x1_usm = dpnp.get_usm_ndarray(x1) + x2_usm = dpnp.get_usm_ndarray(x2) + out_usm = dpnp.get_usm_ndarray(out) + + if ( + out.dtype == dtype + and out.shape == shape + and out.usm_type == usm_type + and out.sycl_queue == sycl_queue + and (out.flags.c_contiguous or out.flags.f_contiguous) + and not ti._array_overlap(x1_usm, out_usm) + and not ti._array_overlap(x2_usm, out_usm) + ): + return out + + return dpnp.empty( + shape, + dtype=dtype, + usm_type=usm_type, + sycl_queue=sycl_queue, + ) + + def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None): """ Creating a copy of input array if needed. @@ -214,14 +249,9 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False): a, b, dtype=None, casting="no", sycl_queue=exec_q ) - # create result array - result = dpnp.empty( - (), - dtype=dot_dtype, - usm_type=res_usm_type, - sycl_queue=exec_q, + result = _create_result_array( + a, b, out, (), dot_dtype, res_usm_type, exec_q ) - # input arrays should have the proper data type dep_events_list = [] host_tasks_list = [] @@ -367,13 +397,10 @@ def dpnp_matmul( x2_shape = x2.shape res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1]) - # calculate results - result = dpnp.empty( - res_shape, - dtype=gemm_dtype, - usm_type=res_usm_type, - sycl_queue=exec_q, + result = _create_result_array( + x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q ) + # calculate result if result.size == 0: pass elif x1.size == 0 or x2.size == 0: