Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dask: Data.diff #350

Merged
merged 1 commit into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 19 additions & 35 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1579,32 +1579,35 @@ def _reset_mask_hardness(self):
"""
self.hardmask = self.hardmask

@daskified(_DASKIFIED_VERBOSE)
@_inplace_enabled(default=False)
def diff(self, axis=-1, n=1, inplace=False):
"""Calculate the n-th discrete difference along the given axis.

The first difference is given by ``x[i+1] - x[i]`` along the given
axis, higher differences are calculated by using `diff`
The first difference is given by ``x[i+1] - x[i]`` along the
given axis; higher differences are calculated by using `diff`
recursively.

The shape of the output is the same as the input except along the
given axis, where the dimension is smaller by *n*. The data type
of the output is the same as the type of the difference between
any two elements of the input.
The shape of the output is the same as the input except along
the given axis, where the dimension is smaller by *n*. The
data type of the output is the same as the type of the
difference between any two elements of the input.

.. versionadded:: 3.2.0

.. seealso:: `cumsum`, `sum`

:Parameters:

axis: int, optional
The axis along which the difference is taken. By default
the last axis is used. The *axis* argument is an integer
that selects the axis corresponding to the given position
in the list of axes of the data array.
The axis along which the difference is taken. By
default the last axis is used. The *axis* argument is
an integer that selects the axis corresponding to the
given position in the list of axes of the data array.

n: int, optional
The number of times values are differenced. If zero, the
input is returned as-is. By default *n* is ``1``.
The number of times values are differenced. If zero,
the input is returned as-is. By default *n* is ``1``.

{{inplace: `bool`, optional}}

Expand All @@ -1614,7 +1617,7 @@ def diff(self, axis=-1, n=1, inplace=False):
The n-th differences, or `None` if the operation was
in-place.

**Examples:**
**Examples**

>>> d = cf.Data(numpy.arange(12.).reshape(3, 4))
>>> d[1, 1] = 4.5
Expand Down Expand Up @@ -1658,28 +1661,9 @@ def diff(self, axis=-1, n=1, inplace=False):
"""
d = _inplace_enabled_define_and_cleanup(self)

if n == 0:
return d

out = d
for _ in range(n):
sections = out.section(axis, chunks=True)

# Diff each section
for key, data in sections.items():
output_array = np.diff(data.array, axis=axis)

sections[key] = type(self)(
output_array, units=self.Units, fill_value=self.fill_value
)

# Glue the sections back together again
out = self.__class__.reconstruct_sectioned_data(sections)

if inplace:
d.__dict__ = out.__dict__
else:
d = out
dx = self._get_dask()
dx = da.diff(dx, axis=axis, n=n)
d._set_dask(dx, reset_mask_hardness=False)

return d

Expand Down
4 changes: 1 addition & 3 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,6 @@ def test_Data_convolution_filter(self):
)
self.assertTrue((e.array == b).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_diff(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return
Expand All @@ -668,8 +667,7 @@ def test_Data_diff(self):
self.assertTrue((d.array == a).all())

e = d.copy()
x = e.diff(inplace=True)
self.assertIsNone(x)
self.assertIsNone(e.diff(inplace=True))
self.assertTrue(e.equals(d.diff()))

for n in (0, 1, 2):
Expand Down