Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dask: Data.flatten #333

Merged
merged 3 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 38 additions & 57 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9989,10 +9989,10 @@ def flat(self, ignore_masked=True):
else:
yield cf_masked

@daskified(_DASKIFIED_VERBOSE)
@_inplace_enabled(default=False)
def flatten(self, axes=None, inplace=False):
"""Flatten axes of the data.

TODODASK - check against daask flatten behaviour
"""Flatten specified axes of the data.

Any subset of the axes may be flattened.

Expand All @@ -10004,21 +10004,16 @@ def flatten(self, axes=None, inplace=False):

.. versionadded:: 3.0.2

.. seealso:: `compressed`, `insert_dimension`, `flip`, `swapaxes`,
`transpose`
.. seealso:: `compressed`, `flat`, `insert_dimension`, `flip`,
`swapaxes`, `transpose`

:Parameters:

axes: (sequence of) int or str, optional
Select the axes. By default all axes are flattened. The
*axes* argument may be one, or a sequence, of:

* An internal axis identifier. Selects this axis.

* An integer. Selects the axis corresponding to the given
position in the list of axes of the data array.

No axes are flattened if *axes* is an empty sequence.
axes: (sequence of) `int`
Select the axes to be flattened. By default all axes
are flattened. Each axis is identified by its integer
position. No axes are flattened if *axes* is an empty
sequence.

{{inplace: `bool`, optional}}

Expand All @@ -10030,7 +10025,8 @@ def flatten(self, axes=None, inplace=False):

**Examples**

>>> d = cf.Data(numpy.arange(24).reshape(1, 2, 3, 4))
>>> import numpy as np
>>> d = cf.Data(np.arange(24).reshape(1, 2, 3, 4))
>>> d
<CF Data(1, 2, 3, 4): [[[[0, ..., 23]]]]>
>>> print(d.array)
Expand Down Expand Up @@ -10078,67 +10074,52 @@ def flatten(self, axes=None, inplace=False):
[15 19 23]]]

"""
if inplace:
d = self
else:
d = self.copy()
d = _inplace_enabled_define_and_cleanup(self)

ndim = self._ndim
ndim = d.ndim
if not ndim:
if axes or axes == 0:
raise ValueError(
"Can't flatten: Can't remove an axis from "
"scalar {}".format(self.__class__.__name__)
"Can't flatten: Can't remove axes from "
f"scalar {self.__class__.__name__}"
)

if inplace:
d = None
return d

shape = list(d._shape)

# Note that it is important that the first axis in the list is
# the left-most flattened axis
if axes is None:
axes = list(range(ndim))
else:
axes = sorted(d._parse_axes(axes))

n_axes = len(axes)
if n_axes <= 1:
if inplace:
d = None
return d

new_shape = [n for i, n in enumerate(shape) if i not in axes]
new_shape.insert(axes[0], np.prod([shape[i] for i in axes]))

out = d.empty(new_shape, dtype=d.dtype, units=d.Units, chunk=True)
out.hardmask = False

n_non_flattened_axes = ndim - n_axes

for key, data in d.section(axes).items():
flattened_array = data.array.flatten()
size = flattened_array.size

first_None_index = key.index(None)

indices = [i for i in key if i is not None]
indices.insert(first_None_index, slice(0, size))

shape = [1] * n_non_flattened_axes
shape.insert(first_None_index, size)

out[tuple(indices)] = flattened_array.reshape(shape)
dx = d._get_dask()

out.hardmask = True
# It is important that the first axis in the list is the
# left-most flattened axis.
#
# E.g. if the shape is (10, 20, 30, 40, 50, 60) and the axes
# to be flattened are [2, 4], then the data must be
# transposed with order [0, 1, 2, 4, 3, 5]
order = [i for i in range(ndim) if i not in axes]
order[axes[0] : axes[0]] = axes
dx = dx.transpose(order)

# Find the flattened shape.
#
# E.g. if the *transposed* shape is (10, 20, 30, 50, 40, 60)
# and *transposed* axes [2, 3] are to be flattened then
# the new shape will be (10, 20, 1500, 40, 60)
shape = d.shape
new_shape = [n for i, n in enumerate(shape) if i not in axes]
new_shape.insert(axes[0], reduce(mul, [shape[i] for i in axes], 1))

if inplace:
d.__dict__ = out.__dict__
out = None
dx = dx.reshape(new_shape)
d._set_dask(dx, reset_mask_hardness=False)

return out
return d

@daskified(_DASKIFIED_VERBOSE)
@_deprecated_kwarg_check("i")
Expand Down
1 change: 0 additions & 1 deletion cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,6 @@ def test_Data_cumsum(self):
e = d.cumsum(axis=i, masked_as_zero=False)
self.assertTrue(cf.functions._numpy_allclose(e.array, b))

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'")
def test_Data_flatten(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return
Expand Down