dask: Dask.digitize #312

Merged
merged 5 commits into from Feb 4, 2022
88 changes: 35 additions & 53 deletions cf/data/data.py
@@ -1656,13 +1656,16 @@ def dumps(self):

return json_dumps(d, default=convert_to_builtin_type)

@daskified(_DASKIFIED_VERBOSE)
@_inplace_enabled(default=False)
def digitize(
self,
bins,
upper=False,
open_ends=False,
closed_ends=None,
return_bins=False,
inplace=False,
):
"""Return the indices of the bins to which each value belongs.

@@ -1747,6 +1750,8 @@ def digitize(
return_bins: `bool`, optional
If True then also return the bins in their 2-d form.

{{inplace: `bool`, optional}}

:Returns:

`Data`, [`Data`]
@@ -1755,7 +1760,7 @@ def digitize(
If *return_bins* is True then also return the bins in
their 2-d form.

**Examples:**
**Examples**

>>> d = cf.Data(numpy.arange(12).reshape(3, 4))
[[ 0 1 2 3]
@@ -1811,9 +1816,9 @@ def digitize(
[ 1 1 1 --]]

"""
out = self.copy()
d = _inplace_enabled_define_and_cleanup(self)

org_units = self.Units
org_units = d.Units

bin_units = getattr(bins, "Units", None)

@@ -1830,12 +1835,16 @@
else:
bin_units = org_units

bins = np.asanyarray(bins)
# Get bins as a numpy array
if isinstance(bins, np.ndarray):
bins = bins.copy()
else:
bins = np.asanyarray(bins)

if bins.ndim > 2:
raise ValueError(
"The 'bins' parameter must be scalar, 1-d or 2-d"
"Got: {!r}".format(bins)
"The 'bins' parameter must be scalar, 1-d or 2-d. "
f"Got: {bins!r}"
)

two_d_bins = None
@@ -1848,7 +1857,7 @@
if bins.shape[1] != 2:
raise ValueError(
"The second dimension of the 'bins' parameter must "
"have size 2. Got: {!r}".format(bins)
f"have size 2. Got: {bins!r}"
)

bins.sort(axis=1)
@@ -1858,11 +1867,9 @@
for i, (u, l) in enumerate(zip(bins[:-1, 1], bins[1:, 0])):
if u > l:
raise ValueError(
"Overlapping bins: {}, {}".format(
tuple(bins[i]), tuple(bins[i + i])
)
f"Overlapping bins: "
f"{tuple(bins[i])}, {tuple(bins[i + i])}"
)
# --- End: for

two_d_bins = bins
bins = np.unique(bins)
@@ -1900,8 +1907,8 @@
"scalar."
)

mx = self.max().datum()
mn = self.min().datum()
mx = d.max().datum()
mn = d.min().datum()
bins = np.linspace(mn, mx, int(bins) + 1, dtype=float)

delete_bins = []
@@ -1913,7 +1920,8 @@
"Can't set open_ends=True when closed_ends is True."
)

bins = bins.astype(float, copy=True)
if bins.dtype.kind != "f":
bins = bins.astype(float, copy=False)

epsilon = np.finfo(float).eps
ndim = bins.ndim
@@ -1923,53 +1931,27 @@
else:
mx = bins[(-1,) * ndim]
bins[(-1,) * ndim] += abs(mx) * epsilon
# --- End: if

if not open_ends:
delete_bins.insert(0, 0)
delete_bins.append(bins.size)

if return_bins and two_d_bins is None:
x = np.empty((bins.size - 1, 2), dtype=bins.dtype)
x[:, 0] = bins[:-1]
x[:, 1] = bins[1:]
two_d_bins = x

config = out.partition_configuration(readonly=True)

for partition in out.partitions.matrix.flat:
partition.open(config)
array = partition.array

mask = None
if np.ma.isMA(array):
mask = array.mask.copy()

array = np.digitize(array, bins, right=upper)

if delete_bins:
for n, d in enumerate(delete_bins):
d -= n
array = np.ma.where(array == d, np.ma.masked, array)
array = np.ma.where(array > d, array - 1, array)
# --- End: if

if mask is not None:
array = np.ma.where(mask, np.ma.masked, array)

partition.subarray = array
partition.Units = _units_None

partition.close()

out.dtype = int

out.override_units(_units_None, inplace=True)
# Digitise the array
dx = d._get_dask()
dx = da.digitize(dx, bins, right=upper)
d._set_dask(dx, reset_mask_hardness=True)
d.override_units(_units_None, inplace=True)

if return_bins:
return out, type(self)(two_d_bins, units=bin_units)
if two_d_bins is None:
two_d_bins = np.empty((bins.size - 1, 2), dtype=bins.dtype)
two_d_bins[:, 0] = bins[:-1]
two_d_bins[:, 1] = bins[1:]

return out
two_d_bins = type(self)(two_d_bins, units=bin_units)
return d, two_d_bins

return d

def median(
self,
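
For context, the digitizing step itself is now delegated to dask.array.digitize, a lazy, chunk-wise wrapper around numpy.digitize. The following standalone sketch is not part of the diff; it reuses the bin edges from the docstring example to show how bin indices are assigned and how right= (driven by the method's upper keyword) switches which end of each bin is closed:

import dask.array as da
import numpy as np

x = da.from_array(np.arange(12).reshape(3, 4), chunks=(3, 2))
bins = np.array([2.0, 6.0, 10.0, 50.0, 100.0])

# right=False (upper=False): bins are closed on the left,
# i.e. bins[i-1] <= x < bins[i] maps to index i
print(da.digitize(x, bins, right=False).compute())
# [[0 0 1 1]
#  [1 1 2 2]
#  [2 2 3 3]]

# right=True (upper=True): bins are closed on the right,
# i.e. bins[i-1] < x <= bins[i] maps to index i
print(da.digitize(x, bins, right=True).compute())
# [[0 0 0 1]
#  [1 1 1 2]
#  [2 2 2 3]]

Index 0 and index len(bins) flag values that fall outside the outermost edges; the method's open_ends/delete_bins handling then decides whether those become extra bins or missing data, as in the masked values shown in the docstring example.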
39 changes: 30 additions & 9 deletions cf/test/test_Data.py
@@ -804,7 +804,6 @@ def test_Data__init__dtype_mask(self):
self.assertTrue((d.array == a).all())
self.assertTrue((d.mask.array == np.ma.getmaskarray(a)).all())

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attr. 'partition_configuration'")
def test_Data_digitize(self):
if self.test_only and inspect.stack()[0][3] not in self.test_only:
return
@@ -829,15 +828,37 @@ def test_Data_digitize(self):
b = np.digitize(a, [2, 6, 10, 50, 100], right=upper)

self.assertTrue((e.array == b).all())

e.where(
cf.set([e.minimum(), e.maximum()]),
cf.masked,
e - 1,
inplace=True,
self.assertTrue(
(np.ma.getmask(e.array) == np.ma.getmask(b)).all()
)
f = d.digitize(bins, upper=upper)
self.assertTrue(e.equals(f, verbose=2))

# TODODASK: Reinstate the following test when
# __sub__, minimum, and maximum have
# been daskified

# e.where(
# cf.set([e.minimum(), e.maximum()]),
# cf.masked,
# e - 1,
# inplace=True,
# )
# f = d.digitize(bins, upper=upper)
# self.assertTrue(e.equals(f, verbose=2))

# Check returned bins
bins = [2, 6, 10, 50, 100]
e, b = d.digitize(bins, return_bins=True)
self.assertTrue(
(b.array == [[2, 6], [6, 10], [10, 50], [50, 100]]).all()
)
self.assertTrue(b.Units == d.Units)

# Check digitized units
self.assertTrue(e.Units == cf.Units(None))

# Check inplace
self.assertIsNone(d.digitize(bins, inplace=True))
self.assertTrue(d.equals(e))

@unittest.skipIf(TEST_DASKIFIED_ONLY, "no attribute '_ndim'")
def test_Data_cumsum(self):
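
Taken together, the new assertions at the end of test_Data_digitize correspond to usage along the following lines. This is an illustrative sketch rather than the test itself, and it assumes a cf installation that includes this branch and data constructed with units of km:

import numpy as np
import cf

d = cf.Data(np.arange(12).reshape(3, 4), units="km")
bins = [2, 6, 10, 50, 100]

# return_bins=True also returns the bins in their 2-d (lower, upper) form;
# plain numeric bin edges carry no units, so they inherit d.Units
e, b = d.digitize(bins, return_bins=True)
print(b.array)                      # consecutive edge pairs: [2, 6], [6, 10], [10, 50], [50, 100]
print(b.Units == d.Units)           # True
print(e.Units == cf.Units(None))    # True - the bin indices are dimensionless

# inplace=True digitizes d itself and returns None
print(d.digitize(bins, inplace=True) is None)   # True
print(d.equals(e))                              # True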