Skip to content

Commit

Permalink
SUM, COUNT, PROPORTION: don't use dask on numpy arrays (SciTools#4905)
Browse files Browse the repository at this point in the history
* COUNT

* SUM

* tidy up

* increase dask min pin

* fix returned weights when masked

* typo

* docstrings

* whatsnew

* include reviewer in whatsnew
  • Loading branch information
rcomer authored and pp-mo committed Sep 26, 2022
1 parent e9c7f98 commit e8314f0
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 29 deletions.
16 changes: 11 additions & 5 deletions docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ This document explains the changes made to Iris for this release
===========

#. `@ESadek-MO`_ edited :func:`~iris.io.expand_filespecs` to allow expansion of
non-existing paths, and added expansion functionality to :func:`~iris.io.save`.
(:issue:`4772`, :pull:`4913`)
non-existing paths, and added expansion functionality to :func:`~iris.io.save`.
(:issue:`4772`, :pull:`4913`)


🐛 Bugs Fixed
=============

#. N/A
#. `@rcomer`_ and `@pp-mo`_ (reviewer) factored masking into the returned
sum-of-weights calculation from :obj:`~iris.analysis.SUM`. (:pull:`4905`)


💣 Incompatible Changes
Expand All @@ -51,7 +52,9 @@ This document explains the changes made to Iris for this release
🚀 Performance Enhancements
===========================

#. N/A
#. `@rcomer`_ and `@pp-mo`_ (reviewer) increased aggregation speed for
:obj:`~iris.analysis.SUM`, :obj:`~iris.analysis.COUNT` and
:obj:`~iris.analysis.PROPORTION` on real data. (:pull:`4905`)


🔥 Deprecations
Expand All @@ -63,7 +66,8 @@ This document explains the changes made to Iris for this release
🔗 Dependencies
===============

#. N/A
#. `@rcomer`_ introduced the ``dask >=2.26`` minimum pin, so that Iris can benefit
from Dask's support for `NEP13`_ and `NEP18`_. (:pull:`4905`)


📚 Documentation
Expand All @@ -89,3 +93,5 @@ This document explains the changes made to Iris for this release
Whatsnew resources in alphabetical order:
.. _NEP13: https://numpy.org/neps/nep-0013-ufunc-overrides.html
.. _NEP18: https://numpy.org/neps/nep-0018-array-function-protocol.html
53 changes: 33 additions & 20 deletions lib/iris/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,18 +1499,21 @@ def _weighted_percentile(
return result


@_build_dask_mdtol_function
def _lazy_count(array, **kwargs):
array = iris._lazy_data.as_lazy_data(array)
def _count(array, **kwargs):
"""
Counts the number of points along the axis that satisfy the condition
specified by ``function``. Uses Dask's support for NEP13/18 to work as
either a lazy or a real function.
"""
func = kwargs.pop("function", None)
if not callable(func):
emsg = "function must be a callable. Got {}."
raise TypeError(emsg.format(type(func)))
return da.sum(func(array), **kwargs)
return np.sum(func(array), **kwargs)


def _proportion(array, function, axis, **kwargs):
count = iris._lazy_data.non_lazy(_lazy_count)
# if the incoming array is masked use that to count the total number of
# values
if ma.isMaskedArray(array):
Expand All @@ -1521,7 +1524,7 @@ def _proportion(array, function, axis, **kwargs):
# case pass the array shape instead of the mask:
total_non_masked = array.shape[axis]
else:
total_non_masked = count(
total_non_masked = _count(
array.mask, axis=axis, function=np.logical_not, **kwargs
)
total_non_masked = ma.masked_equal(total_non_masked, 0)
Expand All @@ -1534,7 +1537,7 @@ def _proportion(array, function, axis, **kwargs):
# a dtype for its data that is different to the dtype of the fill-value,
# which can cause issues outside this function.
# Reference - tests/unit/analyis/test_PROPORTION.py Test_masked.test_ma
numerator = count(array, axis=axis, function=function, **kwargs)
numerator = _count(array, axis=axis, function=function, **kwargs)
result = ma.asarray(numerator / total_non_masked)

return result
Expand Down Expand Up @@ -1604,23 +1607,33 @@ def _lazy_rms(array, axis, **kwargs):
return da.sqrt(da.mean(array**2, axis=axis, **kwargs))


@_build_dask_mdtol_function
def _lazy_sum(array, **kwargs):
array = iris._lazy_data.as_lazy_data(array)
# weighted or scaled sum
def _sum(array, **kwargs):
"""
Weighted or scaled sum. Uses Dask's support for NEP13/18 to work as either
a lazy or a real function.
"""
axis_in = kwargs.get("axis", None)
weights_in = kwargs.pop("weights", None)
returned_in = kwargs.pop("returned", False)
if weights_in is not None:
wsum = da.sum(weights_in * array, **kwargs)
wsum = np.sum(weights_in * array, **kwargs)
else:
wsum = da.sum(array, **kwargs)
wsum = np.sum(array, **kwargs)
if returned_in:
al = da if iris._lazy_data.is_lazy_data(array) else np
if weights_in is None:
weights = iris._lazy_data.as_lazy_data(np.ones_like(array))
weights = al.ones_like(array)
if al is da:
# Dask version of ones_like does not preserve masks. See dask#9301.
weights = da.ma.masked_array(
weights, da.ma.getmaskarray(array)
)
else:
weights = weights_in
rvalue = (wsum, da.sum(weights, axis=axis_in))
weights = al.ma.masked_array(
weights_in, mask=al.ma.getmaskarray(array)
)
rvalue = (wsum, np.sum(weights, axis=axis_in))
else:
rvalue = wsum
return rvalue
Expand Down Expand Up @@ -1740,9 +1753,9 @@ def interp_order(length):
#
COUNT = Aggregator(
"count",
iris._lazy_data.non_lazy(_lazy_count),
_count,
units_func=lambda units: 1,
lazy_func=_lazy_count,
lazy_func=_build_dask_mdtol_function(_count),
)
"""
An :class:`~iris.analysis.Aggregator` instance that counts the number
Expand Down Expand Up @@ -2114,8 +2127,8 @@ def interp_order(length):

SUM = WeightedAggregator(
"sum",
iris._lazy_data.non_lazy(_lazy_sum),
lazy_func=_build_dask_mdtol_function(_lazy_sum),
_sum,
lazy_func=_build_dask_mdtol_function(_sum),
)
"""
An :class:`~iris.analysis.Aggregator` instance that calculates
Expand Down
22 changes: 22 additions & 0 deletions lib/iris/tests/unit/analysis/test_SUM.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# importing anything else.
import iris.tests as tests # isort:skip

import dask.array as da
import numpy as np
import numpy.ma as ma

Expand Down Expand Up @@ -91,6 +92,16 @@ def test_weights_and_returned(self):
self.assertArrayEqual(data, [14, 9, 11, 13, 15])
self.assertArrayEqual(weights, [4, 2, 2, 2, 2])

def test_masked_weights_and_returned(self):
array = ma.array(
self.cube_2d.data, mask=[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0]]
)
data, weights = SUM.aggregate(
array, axis=0, weights=self.weights, returned=True
)
self.assertArrayEqual(data, [14, 9, 8, 4, 15])
self.assertArrayEqual(weights, [4, 2, 1, 1, 2])


class Test_lazy_weights_and_returned(tests.IrisTest):
def setUp(self):
Expand Down Expand Up @@ -128,6 +139,17 @@ def test_weights_and_returned(self):
self.assertArrayEqual(lazy_data.compute(), [14, 9, 11, 13, 15])
self.assertArrayEqual(weights, [4, 2, 2, 2, 2])

def test_masked_weights_and_returned(self):
array = da.ma.masked_array(
self.cube_2d.lazy_data(), mask=[[0, 0, 1, 0, 0], [0, 0, 0, 1, 0]]
)
lazy_data, weights = SUM.lazy_aggregate(
array, axis=0, weights=self.weights, returned=True
)
self.assertTrue(is_lazy_data(lazy_data))
self.assertArrayEqual(lazy_data.compute(), [14, 9, 8, 4, 15])
self.assertArrayEqual(weights, [4, 2, 1, 1, 2])


class Test_aggregate_shape(tests.IrisTest):
def test(self):
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci/py310.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- cartopy >=0.20
- cf-units >=3.1
- cftime >=1.5
- dask-core >=2
- dask-core >=2.26
- matplotlib
- netcdf4
- numpy >=1.19
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci/py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- cartopy >=0.20
- cf-units >=3.1
- cftime >=1.5
- dask-core >=2
- dask-core >=2.26
- matplotlib
- netcdf4
- numpy >=1.19
Expand Down
2 changes: 1 addition & 1 deletion requirements/ci/py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies:
- cartopy >=0.20
- cf-units >=3.1
- cftime >=1.5
- dask-core >=2
- dask-core >=2.26
- matplotlib
- netcdf4
- numpy >=1.19
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ install_requires =
cartopy>=0.20
cf-units>=3.1
cftime>=1.5.0
dask[array]>=2
dask[array]>=2.26
matplotlib
netcdf4
numpy>=1.19
Expand Down

0 comments on commit e8314f0

Please sign in to comment.