Skip to content

Commit

Permalink
First implementation to resolve gh-1134
Browse files Browse the repository at this point in the history
Examples:

```
import dpctl.tensor as dpt

m = dpt.ones((2,4), dtype='i4')
w = dpt.zeros(4)
v = dpt.full(4, -1)

ar = dpt.asarray([m, [w, v]])
ar2 = dpt.asarray([m, [w, v]], device='cpu')
```
  • Loading branch information
oleksandr-pavlyk committed Mar 29, 2023
1 parent 068f65c commit 7f0f48a
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 15 deletions.
132 changes: 117 additions & 15 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.utils
from dpctl.tensor._device import normalize_queue_device
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol

__doc__ = "Implementation of creation functions in :module:`dpctl.tensor`"

Expand Down Expand Up @@ -276,17 +277,6 @@ def _asarray_from_numpy_ndarray(
return res


def _is_object_with_buffer_protocol(obj):
"Returns `True` if object support Python buffer protocol"
try:
# use context manager to ensure
# buffer is instantly released
with memoryview(obj):
return True
except TypeError:
return False


def _ensure_native_dtype_device_support(dtype, dev) -> None:
"""Check that dtype is natively supported by device.
Expand Down Expand Up @@ -318,6 +308,106 @@ def _ensure_native_dtype_device_support(dtype, dev) -> None:
)


def _usm_types_walker(o, usm_types_list):
if isinstance(o, dpt.usm_ndarray):
usm_types_list.append(o.usm_type)
return
if isinstance(o, (list, tuple)):
for el in o:
_usm_types_walker(el, usm_types_list)
return
raise TypeError


def _device_copy_walker(seq_o, res, events):
if isinstance(seq_o, dpt.usm_ndarray):
exec_q = res.sycl_queue
ht_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=seq_o, dst=res, sycl_queue=exec_q
)
events.append(ht_ev)
return
if isinstance(seq_o, (list, tuple)):
for i, el in enumerate(seq_o):
_device_copy_walker(el, res[i], events)
return
raise TypeError


def _copy_through_host_walker(seq_o, usm_res):
if isinstance(seq_o, dpt.usm_ndarray):
usm_res[...] = dpt.asnumpy(seq_o).copy()
return
if isinstance(seq_o, (list, tuple)):
for i, el in enumerate(seq_o):
_copy_through_host_walker(el, usm_res[i])
return
usm_res[...] = np.asarray(seq_o)


def _asarray_from_seq(
seq_obj,
seq_shape,
seq_dt,
seq_dev,
dtype=None,
usm_type=None,
sycl_queue=None,
order="C",
):
"`obj` is a sequence"
if usm_type is None:
usm_types_in_seq = []
_usm_types_walker(seq_obj, usm_types_in_seq)
usm_type = dpctl.utils.get_coerced_usm_type(usm_types_in_seq)
dpctl.utils.validate_usm_type(usm_type)
if sycl_queue is None:
exec_q = seq_dev
alloc_q = seq_dev
else:
exec_q = dpctl.utils.get_execution_queue(
(
sycl_queue,
seq_dev,
)
)
alloc_q = sycl_queue
if dtype is None:
dtype = _map_to_device_dtype(seq_dt, alloc_q)
else:
_mapped_dt = _map_to_device_dtype(dtype, alloc_q)
if _mapped_dt != dtype:
raise ValueError(
f"Device {sycl_queue.sycl_device} "
f"does not support {dtype} natively."
)
dtype = _mapped_dt
if order in "KA":
order = "C"
if isinstance(exec_q, dpctl.SyclQueue):
res = dpt.empty(
seq_shape,
dtype=dtype,
usm_type=usm_type,
sycl_queue=alloc_q,
order=order,
)
ht_events = []
_device_copy_walker(seq_obj, res, ht_events)
dpctl.SyclEvent.wait_for(ht_events)
return res
else:
res = dpt.empty(
seq_shape,
dtype=dtype,
usm_type=usm_type,
sycl_queue=alloc_q,
order=order,
)
_copy_through_host_walker(seq_obj, res)
return res


def asarray(
obj,
dtype=None,
Expand All @@ -327,7 +417,9 @@ def asarray(
sycl_queue=None,
order="K",
):
"""
""" asarray(obj, dtype=None, copy=None, device=None, \
usm_type=None, sycl_queue=None, order="K")
Converts `obj` to :class:`dpctl.tensor.usm_ndarray`.
Args:
Expand All @@ -347,7 +439,7 @@ def asarray(
allocations if possible, but allowed to perform a copy otherwise.
Default: `None`.
order ("C","F","A","K", optional): memory layout of the output array.
Default: "C"
Default: "K"
device (optional): array API concept of device where the output array
is created. `device` can be `None`, a oneAPI filter selector string,
an instance of :class:`dpctl.SyclDevice` corresponding to a
Expand Down Expand Up @@ -452,7 +544,7 @@ def asarray(
raise ValueError(
"Converting Python sequence to usm_ndarray requires a copy"
)
_, _, devs = _array_info_sequence(obj)
seq_shape, seq_dt, devs = _array_info_sequence(obj)
if devs == _host_set:
return _asarray_from_numpy_ndarray(
np.asarray(obj, dtype=dtype, order=order),
Expand All @@ -461,7 +553,17 @@ def asarray(
sycl_queue=sycl_queue,
order=order,
)
# for sequences
elif len(devs) == 1:
return _asarray_from_seq(
obj,
seq_shape,
seq_dt,
list(devs)[0],
dtype=dtype,
usm_type=usm_type,
sycl_queue=sycl_queue,
order=order,
)
raise NotImplementedError(
"Converting Python sequences is not implemented"
)
Expand Down
5 changes: 5 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1629,3 +1629,8 @@ cdef api object UsmNDArray_MakeFromPtr(
offset=offset
)
return arr


def _is_object_with_buffer_protocol(o):
"Returns True if object support Python buffer protocol"
return _is_buffer(o)

0 comments on commit 7f0f48a

Please sign in to comment.