Skip to content

Commit

Permalink
Added finfo_object subclass to np.finfo
Browse files Browse the repository at this point in the history
- Improves array API conformity
  • Loading branch information
ndgrigorian committed Mar 12, 2023
1 parent d7c2e3b commit 4012039
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@
)


class finfo_object(np.finfo):
"""
numpy.finfo subclass which returns Python floating-point scalars for
eps, max, min, and smallest_normal.
"""

def __init__(self, dtype):
_supported_dtype([dpt.dtype(dtype)])
super().__init__()

self.eps = float(self.eps)
self.max = float(self.max)
self.min = float(self.min)

@property
def smallest_normal(self):
return float(super().smallest_normal)

@property
def tiny(self):
return float(super().tiny)


def _broadcast_strides(X_shape, X_strides, res_ndim):
"""
Broadcasts strides to match the given dimensions;
Expand Down Expand Up @@ -495,8 +518,7 @@ def finfo(dtype):
"""
if isinstance(dtype, dpt.usm_ndarray):
raise TypeError("Expected dtype type, got {to}.")
_supported_dtype([dpt.dtype(dtype)])
return np.finfo(dtype)
return finfo_object(dtype)


def _supported_dtype(dtypes):
Expand Down

0 comments on commit 4012039

Please sign in to comment.