Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Possible way to resolve gh-1304 #20

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,28 @@ def add(x1,
# at least either x1 or x2 has to be an array
pass
else:
x1_desc, x2_desc = get_descriptors(x1, x2)
# if not dpnp.isscalar(x1):
# x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
# if x1_desc:
# _x2 = dpnp.asarray(x2, dtype=x1.dtype, sycl_queue=x1.sycl_queue, usm_type=x1.usm_type)
# x2_desc = dpnp.get_dpnp_descriptor(_x2, copy_when_strides=False, copy_when_nondefault_queue=False)
# elif not dpnp.isscalar(x1):
# x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
# if x2_desc:
# _x1 = dpnp.asarray(x1, dtype=x2.dtype, sycl_queue=x2.sycl_queue, usm_type=x2.usm_type)
# x1_desc = dpnp.get_dpnp_descriptor(_x1, copy_when_strides=False, copy_when_nondefault_queue=False)
# else:
# x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
# x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)

# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)
# usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
# x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
# alloc_usm_type=usm_type, alloc_queue=queue)
# x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
# alloc_usm_type=usm_type, alloc_queue=queue)
if x1_desc and x2_desc:
return dpnp_add(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()

Expand Down
30 changes: 30 additions & 0 deletions dpnp/dpnp_utils/dpnp_algo_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ __all__ = [
"dpnp_descriptor",
"get_axis_indeces",
"get_axis_offsets",
"get_descriptors",
"get_usm_allocations",
"_get_linear_index",
"map_dtype_to_device",
Expand Down Expand Up @@ -246,6 +247,35 @@ def _get_common_allocation_queue(objects):
return common_queue


def get_descriptors(x1, x2):
"""
TODO

"""

x1_desc, x2_desc = None, None
get_descr = lambda x: dpnp.get_dpnp_descriptor(x, copy_when_strides=False, copy_when_nondefault_queue=False)

if dpnp.isscalar(x1) and dpnp.isscalar(x2):
pass
if not dpnp.isscalar(x1):
print()
print(f"not dpnp.isscalar(x1): dt={x1.dtype}")
x1_desc = get_descr(x1)
if x1_desc:
_x2 = dpnp.asarray(x2, dtype=x1.dtype, sycl_queue=x1.sycl_queue, usm_type=x1.usm_type)
x2_desc = get_descr(_x2)
elif not dpnp.isscalar(x1):
x2_desc = get_descr(x2)
if x2_desc:
_x1 = dpnp.asarray(x1, dtype=x2.dtype, sycl_queue=x2.sycl_queue, usm_type=x2.usm_type)
x1_desc = get_descr(_x1)
else:
x1_desc = get_descr(x1)
x2_desc = get_descr(x2)
return (x1_desc, x2_desc)


def get_usm_allocations(objects):
"""
Given a list of objects returns a tuple of USM type and SYCL queue
Expand Down
6 changes: 2 additions & 4 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,7 @@ def test_from_dlpack(arr_dtype, shape, device):
Y = dpnp.from_dlpack(X)
assert_array_equal(X, Y)
assert X.__dlpack_device__() == Y.__dlpack_device__()
assert X.sycl_device == Y.sycl_device
assert X.sycl_context == Y.sycl_context
assert_sycl_queue_equal(X.sycl_queue, Y.sycl_queue)
assert X.usm_type == Y.usm_type
if Y.ndim:
V = Y[::-1]
Expand All @@ -868,6 +867,5 @@ def test_from_dlpack_with_dpt(arr_dtype, device):
assert_array_equal(X, Y)
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)
assert X.__dlpack_device__() == Y.__dlpack_device__()
assert X.sycl_device == Y.sycl_device
assert X.sycl_context == Y.sycl_context
assert X.usm_type == Y.usm_type
assert_sycl_queue_equal(X.sycl_queue, Y.sycl_queue)