Skip to content

Commit

Permalink
remove uncalled raise
Browse files Browse the repository at this point in the history
  • Loading branch information
vtavana committed Dec 15, 2023
1 parent 9d24fbd commit a000a54
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
11 changes: 2 additions & 9 deletions dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,15 +489,8 @@ def get_result_array(a, out=None, casting="safe"):
raise ValueError(
f"Output array of shape {a.shape} is needed, got {out.shape}."
)
elif not isinstance(out, dpnp_array):
if isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)
else:
raise TypeError(
"Output array must be any of supported type, but got {}".format(
type(out)
)
)
elif isinstance(out, dpt.usm_ndarray):
out = dpnp_array._create_from_usm_ndarray(out)

dpnp.copyto(out, a, casting=casting)

Expand Down
5 changes: 4 additions & 1 deletion dpnp/dpnp_iface_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,10 @@ def nanvar(
avg = dpnp.divide(avg, cnt, out=avg)

# Compute squared deviation from mean.
arr = dpnp.subtract(arr, avg)
if arr.dtype == avg.dtype:
arr = dpnp.subtract(arr, avg, out=arr)
else:
arr = dpnp.subtract(arr, avg)
dpnp.copyto(arr, 0.0, where=mask)
if dpnp.issubdtype(arr.dtype, dpnp.complexfloating):
sqr = dpnp.multiply(arr, arr.conj(), out=arr).real
Expand Down
10 changes: 8 additions & 2 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,29 +66,35 @@ def test_max_min_out(func):
ia = dpnp.array(a)

np_res = getattr(numpy, func)(a, axis=0)
# output is dpnp array
dpnp_res = dpnp.array(numpy.empty_like(np_res))
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
assert_allclose(dpnp_res, np_res)

# output is usm array
dpnp_res = dpt.asarray(numpy.empty_like(np_res))
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)
assert_allclose(dpnp_res, np_res)

# output is numpy array -> Error
dpnp_res = numpy.empty_like(np_res)
with pytest.raises(TypeError):
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)

# output has incorrect shape -> Error
dpnp_res = dpnp.array(numpy.empty((2, 3)))
with pytest.raises(ValueError):
getattr(dpnp, func)(ia, axis=0, out=dpnp_res)


@pytest.mark.parametrize("func", ["max", "min"])
def test_max_min_NotImplemented(func):
def test_max_min_error(func):
ia = dpnp.arange(5)

# where is not supported
with pytest.raises(NotImplementedError):
getattr(dpnp, func)(ia, where=False)

# initial is not supported
with pytest.raises(NotImplementedError):
getattr(dpnp, func)(ia, initial=6)

Expand Down

0 comments on commit a000a54

Please sign in to comment.