Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changed instances of writeable to writable, added docstrings to _flags.pyx #928

Merged
merged 1 commit into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpctl/memory/_memory.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ cdef class _Memory:
self.memory_ptr = other_buf.p
self.nbytes = other_buf.nbytes
self.queue = other_buf.queue
# self.writeable = other_buf.writeable
# self.writable = other_buf.writable
self.refobj = other
else:
raise ValueError(
Expand Down Expand Up @@ -333,7 +333,7 @@ cdef class _Memory:
def __get__(self):
cdef dict iface = {
"data": (<size_t>(<void *>self.memory_ptr),
True), # bool(self.writeable)),
True), # bool(self.writable)),
"shape": (self.nbytes,),
"strides": None,
"typestr": "|u1",
Expand Down
10 changes: 5 additions & 5 deletions dpctl/memory/_sycl_usm_array_interface_utils.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ cdef class _USMBufferData:
`__sycl_usm_array_interface__` dictionary
"""
cdef DPCTLSyclUSMRef p
cdef int writeable
cdef int writable
cdef object dt
cdef Py_ssize_t itemsize
cdef Py_ssize_t nbytes
Expand All @@ -140,7 +140,7 @@ cdef class _USMBufferData:
cdef size_t arr_data_ptr = 0
cdef DPCTLSyclUSMRef memRef = NULL
cdef Py_ssize_t itemsize = -1
cdef int writeable = -1
cdef int writable = -1
cdef int nd = -1
cdef DPCTLSyclQueueRef QRef = NULL
cdef object dt
Expand All @@ -156,9 +156,9 @@ cdef class _USMBufferData:
if not ary_data_tuple or len(ary_data_tuple) != 2:
raise ValueError("__sycl_usm_array_interface__ is malformed:"
" 'data' field is required, and must be a tuple"
" (usm_pointer, is_writeable_boolean).")
" (usm_pointer, is_writable_boolean).")
arr_data_ptr = <size_t>ary_data_tuple[0]
writeable = 1 if ary_data_tuple[1] else 0
writable = 1 if ary_data_tuple[1] else 0
# Check that memory and syclobj are consistent:
# (USM pointer is bound to this sycl context)
memRef = <DPCTLSyclUSMRef>arr_data_ptr
Expand Down Expand Up @@ -207,7 +207,7 @@ cdef class _USMBufferData:
buf = _USMBufferData.__new__(_USMBufferData)
buf.p = <DPCTLSyclUSMRef>(
arr_data_ptr + (<Py_ssize_t>min_disp) * itemsize)
buf.writeable = writeable
buf.writable = writable
buf.itemsize = itemsize
buf.nbytes = <Py_ssize_t> nbytes

Expand Down
43 changes: 39 additions & 4 deletions dpctl/tensor/_flags.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from libcpp cimport bool as cpp_bool
from dpctl.tensor._usmarray cimport (
USM_ARRAY_C_CONTIGUOUS,
USM_ARRAY_F_CONTIGUOUS,
USM_ARRAY_WRITEABLE,
USM_ARRAY_WRITABLE,
usm_ndarray,
)

Expand All @@ -33,7 +33,10 @@ cdef cpp_bool _check_bit(int flag, int mask):


cdef class Flags:
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
"""
Helper class to represent memory layout flags of
:class:`dpctl.tensor.usm_ndarray`.
"""
cdef int flags_
cdef usm_ndarray arr_

Expand All @@ -43,51 +46,83 @@ cdef class Flags:

@property
def flags(self):
"""
Integer representation of the memory layout flags of
:class:`dpctl.tensor.usm_ndarray` instance.
"""
return self.flags_

@property
def c_contiguous(self):
"""
True if the memory layout of the
:class:`dpctl.tensor.usm_ndarray` instance is C-contiguous.
"""
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)

@property
def f_contiguous(self):
"""
True if the memory layout of the
:class:`dpctl.tensor.usm_ndarray` instance is F-contiguous.
"""
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)

@property
def writable(self):
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)
"""
True if :class:`dpctl.tensor.usm_ndarray` instance is writable.
"""
return _check_bit(self.flags_, USM_ARRAY_WRITABLE)

@property
def fc(self):
"""
True if the memory layout of the :class:`dpctl.tensor.usm_ndarray`
instance is C-contiguous and F-contiguous.
"""
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def forc(self):
"""
True if the memory layout of the :class:`dpctl.tensor.usm_ndarray`
instance is C-contiguous or F-contiguous.
"""
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def fnc(self):
"""
True if the memory layout of the :class:`dpctl.tensor.usm_ndarray`
instance is F-contiguous and not C-contiguous.
"""
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
and not _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def contiguous(self):
"""
True if the memory layout of the :class:`dpctl.tensor.usm_ndarray`
instance is C-contiguous and F-contiguous.
Equivalent to `forc.`
"""
return self.forc

def __getitem__(self, name):
if name in ["C_CONTIGUOUS", "C"]:
return self.c_contiguous
elif name in ["F_CONTIGUOUS", "F"]:
return self.f_contiguous
elif name == "WRITABLE":
elif name in ["WRITABLE", "W"]:
return self.writable
elif name == "FC":
return self.fc
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_stride_utils.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef int ERROR_UNEXPECTED_STRIDES = 3

cdef int USM_ARRAY_C_CONTIGUOUS = 1
cdef int USM_ARRAY_F_CONTIGUOUS = 2
cdef int USM_ARRAY_WRITEABLE = 4
cdef int USM_ARRAY_WRITABLE = 4


cdef Py_ssize_t shape_to_elem_count(int nd, Py_ssize_t *shape_arr):
Expand Down
2 changes: 1 addition & 1 deletion dpctl/tensor/_usmarray.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ cimport dpctl

cdef public api int USM_ARRAY_C_CONTIGUOUS
cdef public api int USM_ARRAY_F_CONTIGUOUS
cdef public api int USM_ARRAY_WRITEABLE
cdef public api int USM_ARRAY_WRITABLE

cdef public api int UAR_BOOL
cdef public api int UAR_BYTE
Expand Down
8 changes: 4 additions & 4 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ cdef class usm_ndarray:
ary_iface = self.base_.__sycl_usm_array_interface__
mem_ptr = <char *>(<size_t> ary_iface['data'][0])
ary_ptr = <char *>(<size_t> self.data_)
ro_flag = False if (self.flags_ & USM_ARRAY_WRITEABLE) else True
ro_flag = False if (self.flags_ & USM_ARRAY_WRITABLE) else True
ary_iface['data'] = (<size_t> mem_ptr, ro_flag)
ary_iface['shape'] = self.shape
if (self.strides_):
Expand Down Expand Up @@ -637,7 +637,7 @@ cdef class usm_ndarray:
buffer=self.base_,
offset=_meta[2]
)
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
res.flags_ |= (self.flags_ & USM_ARRAY_WRITABLE)
res.array_namespace_ = self.array_namespace_
return res

Expand Down Expand Up @@ -1175,7 +1175,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITEABLE)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
return r


Expand All @@ -1192,7 +1192,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
order=('F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'C'),
offset=ary.get_offset()
)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITEABLE)
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
return r


Expand Down
2 changes: 1 addition & 1 deletion dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def test_pyx_capi_check_constants():
assert cc_flag > 0 and 0 == (cc_flag & (cc_flag - 1))
fc_flag = _pyx_capi_int(X, "USM_ARRAY_F_CONTIGUOUS")
assert fc_flag > 0 and 0 == (fc_flag & (fc_flag - 1))
w_flag = _pyx_capi_int(X, "USM_ARRAY_WRITEABLE")
w_flag = _pyx_capi_int(X, "USM_ARRAY_WRITABLE")
assert w_flag > 0 and 0 == (w_flag & (w_flag - 1))

bool_typenum = _pyx_capi_int(X, "UAR_BOOL")
Expand Down