diff --git a/dpnp/dpnp_iface.py b/dpnp/dpnp_iface.py index 74d4aaed6a8..215509c1fc3 100644 --- a/dpnp/dpnp_iface.py +++ b/dpnp/dpnp_iface.py @@ -489,15 +489,8 @@ def get_result_array(a, out=None, casting="safe"): raise ValueError( f"Output array of shape {a.shape} is needed, got {out.shape}." ) - elif not isinstance(out, dpnp_array): - if isinstance(out, dpt.usm_ndarray): - out = dpnp_array._create_from_usm_ndarray(out) - else: - raise TypeError( - "Output array must be any of supported type, but got {}".format( - type(out) - ) - ) + elif isinstance(out, dpt.usm_ndarray): + out = dpnp_array._create_from_usm_ndarray(out) dpnp.copyto(out, a, casting=casting) diff --git a/dpnp/dpnp_iface_nanfunctions.py b/dpnp/dpnp_iface_nanfunctions.py index b6a492f4f1c..966a2c9a578 100644 --- a/dpnp/dpnp_iface_nanfunctions.py +++ b/dpnp/dpnp_iface_nanfunctions.py @@ -384,7 +384,10 @@ def nanvar( avg = dpnp.divide(avg, cnt, out=avg) # Compute squared deviation from mean. - arr = dpnp.subtract(arr, avg) + if arr.dtype == avg.dtype: + arr = dpnp.subtract(arr, avg, out=arr) + else: + arr = dpnp.subtract(arr, avg) dpnp.copyto(arr, 0.0, where=mask) if dpnp.issubdtype(arr.dtype, dpnp.complexfloating): sqr = dpnp.multiply(arr, arr.conj(), out=arr).real diff --git a/tests/test_statistics.py b/tests/test_statistics.py index b0cc01ee70e..3caaaf9c805 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -66,29 +66,35 @@ def test_max_min_out(func): ia = dpnp.array(a) np_res = getattr(numpy, func)(a, axis=0) + # output is dpnp array dpnp_res = dpnp.array(numpy.empty_like(np_res)) getattr(dpnp, func)(ia, axis=0, out=dpnp_res) assert_allclose(dpnp_res, np_res) + # output is usm array dpnp_res = dpt.asarray(numpy.empty_like(np_res)) getattr(dpnp, func)(ia, axis=0, out=dpnp_res) assert_allclose(dpnp_res, np_res) + # output is numpy array -> Error dpnp_res = numpy.empty_like(np_res) with pytest.raises(TypeError): getattr(dpnp, func)(ia, axis=0, out=dpnp_res) + # output has incorrect shape -> Error dpnp_res = dpnp.array(numpy.empty((2, 3))) with pytest.raises(ValueError): getattr(dpnp, func)(ia, axis=0, out=dpnp_res) @pytest.mark.parametrize("func", ["max", "min"]) -def test_max_min_NotImplemented(func): +def test_max_min_error(func): ia = dpnp.arange(5) - + # where is not supported with pytest.raises(NotImplementedError): getattr(dpnp, func)(ia, where=False) + + # initial is not supported with pytest.raises(NotImplementedError): getattr(dpnp, func)(ia, initial=6)