Skip to content

Commit

Permalink
Merge pull request #383 from davidhassell/dask-compressed
Browse files Browse the repository at this point in the history
dask: `Data.compressed`
  • Loading branch information
davidhassell authored Apr 20, 2022
2 parents 2fc1595 + 0d29da5 commit cdea366
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 77 deletions.
66 changes: 16 additions & 50 deletions cf/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def __len__(self):
"""
dx = self._get_dask()
if math.isnan(dx.size):
logger.warning("Computing data len: Performance may be degraded")
logger.debug("Computing data len: Performance may be degraded")
dx.compute_chunk_sizes()

return len(dx)
Expand Down Expand Up @@ -5474,7 +5474,7 @@ def nbytes(self):
"""
dx = self.get_dask(copy=False)
if math.isnan(dx.size):
logger.warning("Computing nbytes: Performance may be degraded")
logger.debug("Computing data nbytes: Performance may be degraded")
dx.compute_chunk_sizes()

return dx.nbytes
Expand Down Expand Up @@ -5541,7 +5541,7 @@ def shape(self):
"""
dx = self.get_dask(copy=False)
if math.isnan(dx.size):
logger.warning("Computing data shape: Performance may be degraded")
logger.debug("Computing data shape: Performance may be degraded")
dx.compute_chunk_sizes()

return dx.shape
Expand Down Expand Up @@ -5582,7 +5582,7 @@ def size(self):
dx = self.get_dask(copy=False)
size = dx.size
if math.isnan(size):
logger.warning("Computing data size: Performance may be degraded")
logger.debug("Computing data size: Performance may be degraded")
dx.compute_chunk_sizes()
size = dx.size

Expand Down Expand Up @@ -7814,7 +7814,7 @@ def compressed(self, inplace=False):
**Examples**
>>> d = cf.Data(numpy.arange(12).reshape(3, 4))
>>> d = cf.Data(numpy.arange(12).reshape(3, 4), 'm')
>>> print(d.array)
[[ 0 1 2 3]
[ 4 5 6 7]
Expand All @@ -7831,58 +7831,24 @@ def compressed(self, inplace=False):
[ 0 1 2 3 4 6 7 8 9 10]
>>> d = cf.Data(9)
>>> print(d.array)
9
>>> print(d.compressed().array)
9
[9]
"""
d = _inplace_enabled_define_and_cleanup(self)

ndim = d.ndim

if ndim != 1:
d.flatten(inplace=True)

n_non_missing = d.count()
if n_non_missing == d.size:
return d

comp = self.empty(
shape=(n_non_missing,), dtype=self.dtype, units=self.Units
dx = d._get_dask()
dx = da.blockwise(
np.ma.compressed,
"i",
dx.ravel(),
"i",
adjust_chunks={"i": lambda n: np.nan},
dtype=dx.dtype,
meta=np.array((), dtype=dx.dtype),
)

# Find the number of array elements that fit in one chunk
n = int(cf_chunksize() // (self.dtype.itemsize + 1.0))

# Loop around each chunk's worth of elements and assign the
# non-missing values to the compressed data
i = 0
start = 0
for _ in range(1 + d.size // n):
if i >= d.size:
break

array = d[i : i + n].array
if np.ma.isMA(array):
array = array.compressed()

size = array.size
if size >= 1:
end = start + size
comp[start:end] = array
start = end

i += n

if not d.ndim:
comp.squeeze(inplace=True)

if inplace:
d.__dict__ = comp.__dict__
else:
d = comp

d._set_dask(dx, reset_mask_hardness=False)
return d

@daskified(_DASKIFIED_VERBOSE)
Expand Down
53 changes: 26 additions & 27 deletions cf/test/test_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,40 +675,39 @@ def test_Data_diff(self):
self.assertTrue((a_diff == d_diff).all())
self.assertTrue((a_diff.mask == d_diff.mask).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'")
def test_Data_compressed(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return

a = np.ma.arange(12).reshape(3, 4)
d = cf.Data(a, "m", chunks=2)
self.assertIsNone(d.compressed(inplace=True))
self.assertEqual(d.shape, (a.size,))
self.assertEqual(d.Units, cf.Units("m"))
self.assertEqual(d.dtype, a.dtype)

d = cf.Data(a)
self.assertTrue((d.array == a).all())
self.assertTrue((a.compressed() == d.compressed()).all())

e = d.copy()
x = e.compressed(inplace=True)
self.assertIsNone(x)
self.assertTrue(e.equals(d.compressed()))

a[1, 1] = np.ma.masked
a[2, 3] = np.ma.masked
d = cf.Data(a, "m", chunks=2)
self.assertTrue((d.compressed().array == a.compressed()).all())

d = cf.Data(a)
self.assertTrue((d.array == a).all())
self.assertTrue((d.mask.array == a.mask).all())
self.assertTrue((a.compressed() == d.compressed()).all())
a[2] = np.ma.masked
d = cf.Data(a, "m", chunks=2)
self.assertTrue((d.compressed().array == a.compressed()).all())

e = d.copy()
x = e.compressed(inplace=True)
self.assertIsNone(x)
self.assertTrue(e.equals(d.compressed()))
a[...] = np.ma.masked
d = cf.Data(a, "m", chunks=2)
e = d.compressed()
self.assertEqual(e.shape, (0,))
self.assertTrue((e.array == a.compressed()).all())

d = cf.Data(self.a, "km")
self.assertTrue((self.a.flatten() == d.compressed()).all())
# Scalar arrays
a = np.ma.array(9)
d = cf.Data(a, "m")
e = d.compressed()
self.assertEqual(e.shape, (1,))
self.assertTrue((e.array == a.compressed()).all())

d = cf.Data(self.ma, "km")
self.assertTrue((self.ma.compressed() == d.compressed()).all())
a = np.ma.array(9, mask=True)
d = cf.Data(a, "m")
e = d.compressed()
self.assertEqual(e.shape, (0,))
self.assertTrue((e.array == a.compressed()).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "Needs __eq__")
def test_Data_stats(self):
Expand Down

0 comments on commit cdea366

Please sign in to comment.