Skip to content

Commit

Permalink
Merge pull request #1110 from IntelPython/array-api-cleanup
Browse files Browse the repository at this point in the history
Improvements to array API conformity
  • Loading branch information
ndgrigorian authored Mar 13, 2023
2 parents 5bfc097 + 4012039 commit cecfdaa
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 58 deletions.
37 changes: 26 additions & 11 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,12 @@ def asarray(


def empty(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` from uninitialized USM allocation.
Expand Down Expand Up @@ -509,7 +514,7 @@ def empty(
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down Expand Up @@ -650,7 +655,12 @@ def arange(


def zeros(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` with zero elements.
Expand Down Expand Up @@ -687,7 +697,7 @@ def zeros(
dtype = _get_dtype(dtype, sycl_queue)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand All @@ -698,7 +708,12 @@ def zeros(


def ones(
sh, dtype=None, order="C", device=None, usm_type="device", sycl_queue=None
shape,
dtype=None,
order="C",
device=None,
usm_type="device",
sycl_queue=None,
):
"""
Creates `usm_ndarray` with elements of one.
Expand Down Expand Up @@ -734,7 +749,7 @@ def ones(
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
dtype = _get_dtype(dtype, sycl_queue)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand All @@ -746,7 +761,7 @@ def ones(


def full(
sh,
shape,
fill_value,
dtype=None,
order="C",
Expand Down Expand Up @@ -805,14 +820,14 @@ def full(
usm_type=usm_type,
sycl_queue=sycl_queue,
)
return dpt.copy(dpt.broadcast_to(X, sh), order=order)
return dpt.copy(dpt.broadcast_to(X, shape), order=order)

sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
usm_type = usm_type if usm_type is not None else "device"
fill_value_type = type(fill_value)
dtype = _get_dtype(dtype, sycl_queue, ref_type=fill_value_type)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down Expand Up @@ -872,11 +887,11 @@ def empty_like(
if device is None and sycl_queue is None:
device = x.device
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
sh = x.shape
shape = x.shape
dtype = dpt.dtype(dtype)
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
res = dpt.usm_ndarray(
sh,
shape,
dtype=dtype,
buffer=usm_type,
order=order,
Expand Down
78 changes: 50 additions & 28 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@
)


class finfo_object(np.finfo):
"""
numpy.finfo subclass which returns Python floating-point scalars for
eps, max, min, and smallest_normal.
"""

def __init__(self, dtype):
_supported_dtype([dpt.dtype(dtype)])
super().__init__()

self.eps = float(self.eps)
self.max = float(self.max)
self.min = float(self.min)

@property
def smallest_normal(self):
return float(super().smallest_normal)

@property
def tiny(self):
return float(super().tiny)


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
Expand Down Expand Up @@ -122,46 +145,46 @@ def permute_dims(X, axes):
)


def expand_dims(X, axes):
def expand_dims(X, axis):
"""
expand_dims(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
expand_dims(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
Expands the shape of an array by inserting a new axis (dimension)
of size one at the position specified by axes; returns a view, if possible,
of size one at the position specified by axis; returns a view, if possible,
a copy otherwise with the number of dimensions increased.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if not isinstance(axes, (tuple, list)):
axes = (axes,)
if not isinstance(axis, (tuple, list)):
axis = (axis,)

out_ndim = len(axes) + X.ndim
axes = normalize_axis_tuple(axes, out_ndim)
out_ndim = len(axis) + X.ndim
axis = normalize_axis_tuple(axis, out_ndim)

shape_it = iter(X.shape)
shape = tuple(1 if ax in axes else next(shape_it) for ax in range(out_ndim))
shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim))

return dpt.reshape(X, shape)


def squeeze(X, axes=None):
def squeeze(X, axis=None):
"""
squeeze(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
squeeze(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
Removes singleton dimensions (axes) from X; returns a view, if possible,
Removes singleton dimensions (axis) from X; returns a view, if possible,
a copy otherwise, but with all or a subset of the dimensions
of length 1 removed.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
X_shape = X.shape
if axes is not None:
if not isinstance(axes, (tuple, list)):
axes = (axes,)
axes = normalize_axis_tuple(axes, X.ndim if X.ndim != 0 else X.ndim + 1)
if axis is not None:
if not isinstance(axis, (tuple, list)):
axis = (axis,)
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
new_shape = []
for i, x in enumerate(X_shape):
if i not in axes:
if i not in axis:
new_shape.append(x)
else:
if x != 1:
Expand Down Expand Up @@ -222,9 +245,9 @@ def broadcast_arrays(*args):
return [broadcast_to(X, shape) for X in args]


def flip(X, axes=None):
def flip(X, axis=None):
"""
flip(X: usm_ndarray, axes: int or tuple or list) -> usm_ndarray
flip(X: usm_ndarray, axis: int or tuple or list) -> usm_ndarray
Reverses the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered;
Expand All @@ -233,20 +256,20 @@ def flip(X, axes=None):
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
X_ndim = X.ndim
if axes is None:
if axis is None:
indexer = (np.s_[::-1],) * X_ndim
else:
axes = normalize_axis_tuple(axes, X_ndim)
axis = normalize_axis_tuple(axis, X_ndim)
indexer = tuple(
np.s_[::-1] if i in axes else np.s_[:] for i in range(X.ndim)
np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim)
)
return X[indexer]


def roll(X, shift, axes=None):
def roll(X, shift, axis=None):
"""
roll(X: usm_ndarray, shift: int or tuple or list,\
axes: int or tuple or list) -> usm_ndarray
axis: int or tuple or list) -> usm_ndarray
Rolls array elements along a specified axis.
Array elements that roll beyond the last position are re-introduced
Expand All @@ -257,7 +280,7 @@ def roll(X, shift, axes=None):
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
if axes is None:
if axis is None:
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue
)
Expand All @@ -266,8 +289,8 @@ def roll(X, shift, axes=None):
)
hev.wait()
return res
axes = normalize_axis_tuple(axes, X.ndim, allow_duplicate=True)
broadcasted = np.broadcast(shift, axes)
axis = normalize_axis_tuple(axis, X.ndim, allow_duplicate=True)
broadcasted = np.broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError("'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(X.ndim)}
Expand Down Expand Up @@ -495,8 +518,7 @@ def finfo(dtype):
"""
if isinstance(dtype, dpt.usm_ndarray):
raise TypeError("Expected dtype type, got {to}.")
_supported_dtype([dpt.dtype(dtype)])
return np.finfo(dtype)
return finfo_object(dtype)


def _supported_dtype(dtypes):
Expand Down
28 changes: 14 additions & 14 deletions dpctl/tensor/_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
return new_sts if valid else None


def reshape(X, newshape, order="C", copy=None):
def reshape(X, shape, order="C", copy=None):
"""
reshape(X: usm_ndarray, newshape: tuple, order="C") -> usm_ndarray
reshape(X: usm_ndarray, shape: tuple, order="C") -> usm_ndarray
Reshapes given usm_ndarray into new shape. Returns a view, if possible,
a copy otherwise. Memory layout of the copy is controlled by order keyword.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError
if not isinstance(newshape, (list, tuple)):
newshape = (newshape,)
if not isinstance(shape, (list, tuple)):
shape = (shape,)
if order in "cfCF":
order = order.upper()
else:
Expand All @@ -97,9 +97,9 @@ def reshape(X, newshape, order="C", copy=None):
f"Keyword 'copy' not recognized. Expecting True, False, "
f"or None, got {copy}"
)
newshape = [operator.index(d) for d in newshape]
shape = [operator.index(d) for d in shape]
negative_ones_count = 0
for nshi in newshape:
for nshi in shape:
if nshi == -1:
negative_ones_count = negative_ones_count + 1
if (nshi < -1) or negative_ones_count > 1:
Expand All @@ -108,14 +108,14 @@ def reshape(X, newshape, order="C", copy=None):
"value which can only be -1"
)
if negative_ones_count:
v = X.size // (-np.prod(newshape))
newshape = [v if d == -1 else d for d in newshape]
if X.size != np.prod(newshape):
raise ValueError(f"Can not reshape into {newshape}")
v = X.size // (-np.prod(shape))
shape = [v if d == -1 else d for d in shape]
if X.size != np.prod(shape):
raise ValueError(f"Can not reshape into {shape}")
if X.size:
newsts = reshaped_strides(X.shape, X.strides, newshape, order=order)
newsts = reshaped_strides(X.shape, X.strides, shape, order=order)
else:
newsts = (1,) * len(newshape)
newsts = (1,) * len(shape)
copy_required = newsts is None
if copy_required and (copy is False):
raise ValueError(
Expand All @@ -141,11 +141,11 @@ def reshape(X, newshape, order="C", copy=None):
flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
)
return dpt.usm_ndarray(
tuple(newshape), dtype=X.dtype, buffer=flat_res, order=order
tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
)
# can form a view
return dpt.usm_ndarray(
newshape,
shape,
dtype=X.dtype,
buffer=X,
strides=tuple(newsts),
Expand Down
10 changes: 5 additions & 5 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def test_incompatible_shapes_raise_valueerror(shapes):
assert_broadcast_arrays_raise(input_shapes[::-1])


def test_flip_axes_incorrect():
def test_flip_axis_incorrect():
try:
q = dpctl.SyclQueue()
except dpctl.SyclQueueCreationError:
Expand All @@ -492,10 +492,10 @@ def test_flip_axes_incorrect():
X_np = np.ones((4, 4))
X = dpt.asarray(X_np, sycl_queue=q)

pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axes=1)
pytest.raises(np.AxisError, dpt.flip, X, axes=2)
pytest.raises(np.AxisError, dpt.flip, X, axes=-3)
pytest.raises(np.AxisError, dpt.flip, X, axes=(0, 3))
pytest.raises(np.AxisError, dpt.flip, dpt.asarray(np.ones(4)), axis=1)
pytest.raises(np.AxisError, dpt.flip, X, axis=2)
pytest.raises(np.AxisError, dpt.flip, X, axis=-3)
pytest.raises(np.AxisError, dpt.flip, X, axis=(0, 3))


def test_flip_0d():
Expand Down

0 comments on commit cecfdaa

Please sign in to comment.