diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 65884951de..57c5225a3e 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -1014,14 +1014,30 @@ cdef usm_ndarray _real_view(usm_ndarray ary): """ View into real parts of a complex type array """ - cdef usm_ndarray r = ary._clone() + cdef int r_typenum_ = -1 + cdef usm_ndarray r = None + cdef Py_ssize_t offset_elems = 0 + if (ary.typenum_ == UAR_CFLOAT): - r.typenum_ = UAR_FLOAT + r_typenum_ = UAR_FLOAT elif (ary.typenum_ == UAR_CDOUBLE): - r.typenum_ = UAR_DOUBLE + r_typenum_ = UAR_DOUBLE else: raise InternalUSMArrayError( "_real_view call on array of non-complex type.") + + offset_elems = ary.get_offset() * 2 + r = usm_ndarray.__new__( + usm_ndarray, + _make_int_tuple(ary.nd_, ary.shape_), + dtype=_make_typestr(r_typenum_), + strides=tuple(2 * si for si in ary.strides), + buffer=ary.base_, + offset=offset_elems, + order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F') + ) + r.flags_ = ary.flags_ + r.array_namespace_ = ary.array_namespace_ return r @@ -1029,16 +1045,31 @@ cdef usm_ndarray _imag_view(usm_ndarray ary): """ View into imaginary parts of a complex type array """ - cdef usm_ndarray r = ary._clone() + cdef int r_typenum_ = -1 + cdef usm_ndarray r = None + cdef Py_ssize_t offset_elems = 0 + if (ary.typenum_ == UAR_CFLOAT): - r.typenum_ = UAR_FLOAT + r_typenum_ = UAR_FLOAT elif (ary.typenum_ == UAR_CDOUBLE): - r.typenum_ = UAR_DOUBLE + r_typenum_ = UAR_DOUBLE else: raise InternalUSMArrayError( - "_real_view call on array of non-complex type.") + "_imag_view call on array of non-complex type.") + # displace pointer to imaginary part - r.data_ = r.data_ + type_bytesize(r.typenum_) + offset_elems = 2 * ary.get_offset() + 1 + r = usm_ndarray.__new__( + usm_ndarray, + _make_int_tuple(ary.nd_, ary.shape_), + dtype=_make_typestr(r_typenum_), + strides=tuple(2 * si for si in ary.strides), + buffer=ary.base_, + offset=offset_elems, + order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F') + ) + r.flags_ = ary.flags_ + r.array_namespace_ = ary.array_namespace_ return r @@ -1054,7 +1085,8 @@ cdef usm_ndarray _transpose(usm_ndarray ary): _make_reversed_int_tuple(ary.nd_, ary.strides_) if (ary.strides_) else None), buffer=ary.base_, - order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C') + order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'), + offset=ary.get_offset() ) r.flags_ |= (ary.flags_ & USM_ARRAY_WRITEABLE) return r diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index fc78d66bfe..b5fab57566 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -841,3 +841,25 @@ def test_reshape(): dpt.reshape(Z, Z.shape, order="invalid") W = dpt.reshape(Z, (-1,), order="C") assert W.shape == (Z.size,) + + +def test_transpose(): + n, m = 2, 3 + X = dpt.usm_ndarray((n, m), "f4") + Xnp = np.arange(n * m, dtype="f4").reshape((n, m)) + X[:] = Xnp + assert np.array_equal(dpt.to_numpy(X.T), Xnp.T) + assert np.array_equal(dpt.to_numpy(X[1:].T), Xnp[1:].T) + + +def test_real_imag_views(): + n, m = 2, 3 + X = dpt.usm_ndarray((n, m), "c8") + Xnp_r = np.arange(n * m, dtype="f4").reshape((n, m)) + Xnp_i = np.arange(n * m, 2 * n * m, dtype="f4").reshape((n, m)) + Xnp = Xnp_r + 1j * Xnp_i + X[:] = Xnp + assert np.array_equal(dpt.to_numpy(X.real), Xnp.real) + assert np.array_equal(dpt.to_numpy(X.imag), Xnp.imag) + assert np.array_equal(dpt.to_numpy(X[1:].real), Xnp[1:].real) + assert np.array_equal(dpt.to_numpy(X[1:].imag), Xnp[1:].imag)