diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py
index c49862d70f..87c7179dcc 100644
--- a/dpctl/tensor/_copy_utils.py
+++ b/dpctl/tensor/_copy_utils.py
@@ -66,15 +66,22 @@ def contract_iter2(shape, strides1, strides2):
 
 
 def _has_memory_overlap(x1, x2):
-    m1 = dpm.as_usm_memory(x1)
-    m2 = dpm.as_usm_memory(x2)
-    if m1.sycl_device == m2.sycl_device:
-        p1_beg = m1._pointer
-        p1_end = p1_beg + m1.nbytes
-        p2_beg = m2._pointer
-        p2_end = p2_beg + m2.nbytes
-        return p1_beg > p2_end or p2_beg < p1_end
+    """Return True if the USM memory underlying x1 and x2 may overlap."""
+    if x1.size and x2.size:
+        m1 = dpm.as_usm_memory(x1)
+        m2 = dpm.as_usm_memory(x2)
+        # can only overlap if bound to the same context
+        if m1.sycl_context == m2.sycl_context:
+            p1_beg = m1._pointer
+            p1_end = p1_beg + m1.nbytes
+            p2_beg = m2._pointer
+            p2_end = p2_beg + m2.nbytes
+            # ranges intersect iff each starts before the other ends
+            return p1_beg < p2_end and p2_beg < p1_end
+        else:
+            return False
     else:
+        # zero-element arrays do not overlap anything
         return False
 
 
@@ -193,6 +200,9 @@ def copy_same_dtype(dst, src):
     if dst.dtype != src.dtype:
         raise ValueError
 
+    if dst.size == 0:
+        return
+
     # check that memory regions do not overlap
     if _has_memory_overlap(dst, src):
         tmp = _copy_to_numpy(src)
diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py
index e8ee79a391..c9e22ece41 100644
--- a/dpctl/tests/test_usm_ndarray_ctor.py
+++ b/dpctl/tests/test_usm_ndarray_ctor.py
@@ -600,6 +600,8 @@ def test_setitem_same_dtype(dtype, src_usm_type, dst_usm_type):
     R2 = np.broadcast_to(Xnp[0], R1.shape)
     assert R1.shape == R2.shape
     assert np.allclose(R1, R2)
+    Zusm_empty = Zusm_1d[0:0]
+    Zusm_empty[Ellipsis] = Zusm_3d[0, 0, 0:0]
 
 
 @pytest.mark.parametrize(