Skip to content

Commit

Permalink
Merge pull request #308 from sadielbartholomew/func-to-dask-built-ins…
Browse files Browse the repository at this point in the history
…-non-trig

LAMA to Dask: `func` -> built-in conversion for non-trig. methods
  • Loading branch information
sadielbartholomew authored Feb 2, 2022
2 parents 0206338 + 5956404 commit ae487ed
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
56 changes: 41 additions & 15 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2824,6 +2824,7 @@ def can_compute(self, functions=None, log_levels=None, override=False):

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def ceil(self, inplace=False, i=False):
"""The ceiling of the data, element-wise.
Expand Down Expand Up @@ -2855,7 +2856,10 @@ def ceil(self, inplace=False, i=False):
[-1. -1. -1. -1. 0. 1. 2. 2. 2.]
"""
return self.func(np.ceil, inplace=inplace)
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()
d._set_dask(da.ceil(dx), reset_mask_hardness=False)
return d

@daskified(_DASKIFIED_VERBOSE)
@_inplace_enabled(default=False)
Expand Down Expand Up @@ -9213,13 +9217,14 @@ def exp(self, inplace=False, i=False):
if units and not units.isdimensionless:
raise ValueError(
"Can't take exponential of dimensional "
"quantities: {!r}".format(units)
f"quantities: {units!r}"
)

if d.Units:
d.Units = _units_1

d.func(np.exp, inplace=True)
dx = d._get_dask()
d._set_dask(da.exp(dx), reset_mask_hardness=False)

return d

Expand Down Expand Up @@ -10137,6 +10142,7 @@ def flatten(self, axes=None, inplace=False):

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def floor(self, inplace=False, i=False):
"""Return the floor of the data array.
Expand All @@ -10163,7 +10169,10 @@ def floor(self, inplace=False, i=False):
[-2. -2. -2. -1. 0. 1. 1. 1. 1.]
"""
return self.func(np.floor, inplace=inplace)
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()
d._set_dask(da.floor(dx), reset_mask_hardness=False)
return d

@_deprecated_kwarg_check("i")
def outerproduct(self, e, inplace=False, i=False):
Expand Down Expand Up @@ -11022,6 +11031,7 @@ def isclose(self, y, rtol=None, atol=None):

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def rint(self, inplace=False, i=False):
"""Round the data to the nearest integer, element-wise.
Expand Down Expand Up @@ -11050,7 +11060,10 @@ def rint(self, inplace=False, i=False):
[-2. -2. -1. -1. 0. 1. 1. 2. 2.]
"""
return self.func(np.rint, inplace=inplace)
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()
d._set_dask(da.rint(dx), reset_mask_hardness=False)
return d

def root_mean_square(
self,
Expand Down Expand Up @@ -11133,6 +11146,7 @@ def root_mean_square(

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def round(self, decimals=0, inplace=False, i=False):
"""Evenly round elements of the data array to the given number
of decimals.
Expand Down Expand Up @@ -11176,7 +11190,10 @@ def round(self, decimals=0, inplace=False, i=False):
[-0., -0., -0., -0., 0., 0., 0., 0., 0.]
"""
return self.func(np.round, inplace=inplace, decimals=decimals)
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()
d._set_dask(da.round(dx, decimals=decimals), reset_mask_hardness=False)
return d

def stats(
self,
Expand Down Expand Up @@ -11999,9 +12016,7 @@ def tanh(self, inplace=False):

return d

# TODOASK: daskified except in the case of arbitrary base (not e, 2 or 10)
# which requires `__itruediv__` to be daskified.
# @daskified(_DASKIFIED_VERBOSE)
@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def log(self, base=None, inplace=False, i=False):
Expand All @@ -12021,16 +12036,23 @@ def log(self, base=None, inplace=False, i=False):
"""
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()

if base is None:
d.func(np.log, units=_units_1, inplace=True)
dx = da.log(dx)
elif base == 10:
d.func(np.log10, units=_units_1, inplace=True)
dx = da.log10(dx)
elif base == 2:
d.func(np.log2, units=_units_1, inplace=True)
dx = da.log2(dx)
else:
d.func(np.log, units=_units_1, inplace=True)
d /= np.log(base)
dx = da.log(dx)
dx /= da.log(base)

d._set_dask(dx, reset_mask_hardness=False)

d.override_units(
_units_1, inplace=True
) # all logarithm outputs are unitless

return d

Expand Down Expand Up @@ -12302,6 +12324,7 @@ def transpose(self, axes=None, inplace=False, i=False):

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
@_inplace_enabled(default=False)
def trunc(self, inplace=False, i=False):
"""Return the truncated values of the data array.
Expand Down Expand Up @@ -12332,7 +12355,10 @@ def trunc(self, inplace=False, i=False):
[-1. -1. -1. -1. 0. 1. 1. 1. 1.]
"""
return self.func(np.trunc, inplace=inplace)
d = _inplace_enabled_define_and_cleanup(self)
dx = d._get_dask()
d._set_dask(da.trunc(dx), reset_mask_hardness=False)
return d

@classmethod
def empty(
Expand Down
21 changes: 11 additions & 10 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3391,16 +3391,17 @@ def test_Data_log(self):
self.assertEqual(d.shape, b.shape)

# Test an arbitrary base, using 4 (not a special managed case like 10)
# TODODASK: reinstate this assertion once mask property is
# daskified.
# a = np.array([[4, 16, 4**3.5], [0, 1, 0.25]])
# b = np.log(a) / np.log(4) # the numpy way, using log rules from school
# c = cf.Data(a, "s")
# d = c.log(base=4)
# self.assertTrue((d.array == b).all())
# self.assertEqual(d.shape, b.shape)

# Text values outside of the restricted domain for a log
a = np.array([[4, 16, 4 ** 3.5], [0, 1, 0.25]])
b = np.log(a) / np.log(4) # the numpy way, using log rules from school
c = cf.Data(a, "s")
d = c.log(base=4)
self.assertTrue((d.array == b).all())
self.assertEqual(d.shape, b.shape)

# Check units for general case
self.assertEqual(d.Units, cf.Units("1"))

# Text values outside of the restricted domain for a logarithm
a = np.array([0, -1, -2])
b = np.log(a)
c = cf.Data(a)
Expand Down

0 comments on commit ae487ed

Please sign in to comment.