Skip to content

Commit

Permalink
Made Flags cdef class, used typed members, and typed functions, added…
Browse files Browse the repository at this point in the history
… __eq__

Support for `__eq__` allows to compare two instances of Flags objects and
compare Flags object to an integer.
  • Loading branch information
oleksandr-pavlyk authored and ndgrigorian committed Oct 10, 2022
1 parent 2d20e1e commit ec3f8e3
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions dpctl/tensor/_flags.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,26 @@
# cython: language_level=3
# cython: linetrace=True

from libcpp cimport bool as cpp_bool

from dpctl.tensor._usmarray cimport (
USM_ARRAY_C_CONTIGUOUS,
USM_ARRAY_F_CONTIGUOUS,
USM_ARRAY_WRITEABLE,
usm_ndarray,
)


class Flags:
cdef cpp_bool _check_bit(int flag, int mask):
return (flag & mask) == mask


def __init__(self, arr, flags):
cdef class Flags:
"""Helper class to represent flags of :class:`dpctl.tensor.usm_ndarray`."""
cdef int flags_
cdef usm_ndarray arr_

def __cinit__(self, usm_ndarray arr, int flags):
self.arr_ = arr
self.flags_ = flags

Expand All @@ -37,32 +47,29 @@ class Flags:

@property
def c_contiguous(self):
return ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
== USM_ARRAY_C_CONTIGUOUS)
return _check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)

@property
def f_contiguous(self):
return ((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
== USM_ARRAY_F_CONTIGUOUS)
return _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)

@property
def writable(self):
return False if ((self.flags & USM_ARRAY_WRITEABLE)
== USM_ARRAY_WRITEABLE) else True
return _check_bit(self.flags_, USM_ARRAY_WRITEABLE)

@property
def forc(self):
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
== USM_ARRAY_F_CONTIGUOUS)
or ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
== USM_ARRAY_C_CONTIGUOUS)) else False
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
or _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def fnc(self):
return True if (((self.flags_ & USM_ARRAY_F_CONTIGUOUS)
== USM_ARRAY_F_CONTIGUOUS)
and not ((self.flags_ & USM_ARRAY_C_CONTIGUOUS)
== USM_ARRAY_C_CONTIGUOUS)) else False
return (
_check_bit(self.flags_, USM_ARRAY_C_CONTIGUOUS)
and _check_bit(self.flags_, USM_ARRAY_F_CONTIGUOUS)
)

@property
def contiguous(self):
Expand All @@ -83,3 +90,13 @@ class Flags:
for name in "C_CONTIGUOUS", "F_CONTIGUOUS", "WRITABLE":
out.append(" {} : {}".format(name, self[name]))
return '\n'.join(out)

def __eq__(self, other):
cdef Flags other_
if isinstance(other, self.__class__):
other_ = <Flags>other
return self.flags_ == other_.flags_
elif isinstance(other, int):
return self.flags_ == <int>other
else:
return False

0 comments on commit ec3f8e3

Please sign in to comment.