Skip to content

Commit

Permalink
Improve performance of dpnp.matmul and dpnp.dot with out keyword (
Browse files Browse the repository at this point in the history
#1694)

* use out keyword for result

* fix strided or overlapping out

* address comments

* fix typo

* remove additional check
  • Loading branch information
vtavana authored Feb 8, 2024
1 parent 1a3866e commit d45bb24
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 22 deletions.
21 changes: 12 additions & 9 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,17 +495,20 @@ def get_result_array(a, out=None, casting="safe"):
if out is None:
return a
else:
dpnp.check_supported_arrays_type(out)
if out.shape != a.shape:
raise ValueError(
f"Output array of shape {a.shape} is needed, got {out.shape}."
)
elif isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
if a is out:
return out
else:
dpnp.check_supported_arrays_type(out)
if out.shape != a.shape:
raise ValueError(
f"Output array of shape {a.shape} is needed, got {out.shape}."
)
elif isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)

dpnp.copyto(out, a, casting=casting)
dpnp.copyto(out, a, casting=casting)

return out
return out


def get_usm_ndarray(a):
Expand Down
53 changes: 40 additions & 13 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@
__all__ = ["dpnp_dot", "dpnp_matmul"]


def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
"""
Create the result array.
If `out` is not ``None`` and its features match the specified `shape`, `dtype,
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
If these conditions are not statisfied, an empty array is returned with the
specified `shape`, `dtype, `usm_type`, and `sycl_queue`.
"""

if out is not None:
x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2)
out_usm = dpnp.get_usm_ndarray(out)

if (
out.dtype == dtype
and out.shape == shape
and out.usm_type == usm_type
and out.sycl_queue == sycl_queue
and (out.flags.c_contiguous or out.flags.f_contiguous)
and not ti._array_overlap(x1_usm, out_usm)
and not ti._array_overlap(x2_usm, out_usm)
):
return out

return dpnp.empty(
shape,
dtype=dtype,
usm_type=usm_type,
sycl_queue=sycl_queue,
)


def _copy_array(x, dep_events, host_events, contig_copy=False, dtype=None):
"""
Creating a copy of input array if needed.
Expand Down Expand Up @@ -214,14 +249,9 @@ def dpnp_dot(a, b, /, out=None, *, conjugate=False):
a, b, dtype=None, casting="no", sycl_queue=exec_q
)

# create result array
result = dpnp.empty(
(),
dtype=dot_dtype,
usm_type=res_usm_type,
sycl_queue=exec_q,
result = _create_result_array(
a, b, out, (), dot_dtype, res_usm_type, exec_q
)

# input arrays should have the proper data type
dep_events_list = []
host_tasks_list = []
Expand Down Expand Up @@ -367,13 +397,10 @@ def dpnp_matmul(
x2_shape = x2.shape
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])

# calculate results
result = dpnp.empty(
res_shape,
dtype=gemm_dtype,
usm_type=res_usm_type,
sycl_queue=exec_q,
result = _create_result_array(
x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q
)
# calculate result
if result.size == 0:
pass
elif x1.size == 0 or x2.size == 0:
Expand Down

0 comments on commit d45bb24

Please sign in to comment.