Skip to content

Commit

Permalink
Make weights handling simpler by explicitly calling _Weights constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
schlunma committed Jun 19, 2023
1 parent 0e5b625 commit fb466ba
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 64 deletions.
29 changes: 0 additions & 29 deletions lib/iris/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,35 +1288,6 @@ def __init__(self, weights, cube):
self.array = derived_array
self.units = derived_units

@classmethod
def get_updated_kwargs(cls, kwargs, cube):
"""Get kwargs and weights units updated with `weights` information.
Args:
* kwargs (dict):
Keyword arguments that will be updated if a `weights` keyword is
present which is not ``None``.
* cube (Cube):
Input cube for aggregation. If weights is given as :obj:`str`, try
to extract a cell measure with the corresponding name from this
cube. Otherwise, this argument is ignored.
Returns:
tuple. A tuple containing the updated keyword arguments and the
weights units.
"""
kwargs = dict(kwargs)
weights_units = None

if kwargs.get("weights") is not None:
weights = cls(kwargs["weights"], cube)
kwargs["weights"] = weights.array
weights_units = weights.units

return (kwargs, weights_units)


def create_weighted_aggregator_fn(aggregator_fn, axis, **kwargs):
"""Return an aggregator function that can explicitly handle weights.
Expand Down
21 changes: 15 additions & 6 deletions lib/iris/cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -3836,7 +3836,10 @@ def collapsed(self, coords, aggregator, **kwargs):
"""
# Update weights kwargs (if necessary) to handle different types of
# weights
(kwargs, weights_units) = _Weights.get_updated_kwargs(kwargs, self)
weights_info = None
if kwargs.get("weights") is not None:
weights_info = _Weights(kwargs["weights"], self)
kwargs["weights"] = weights_info.array

# Convert any coordinate names to coordinates
coords = self._as_list_of_coords(coords)
Expand Down Expand Up @@ -3984,7 +3987,7 @@ def collapsed(self, coords, aggregator, **kwargs):
collapsed_cube,
coords,
axis=collapse_axis,
_weights_units=weights_units,
_weights_units=getattr(weights_info, "units", None),
**kwargs,
)
result = aggregator.post_process(
Expand Down Expand Up @@ -4078,7 +4081,10 @@ def aggregated_by(
"""
# Update weights kwargs (if necessary) to handle different types of
# weights
(kwargs, weights_units) = _Weights.get_updated_kwargs(kwargs, self)
weights_info = None
if kwargs.get("weights") is not None:
weights_info = _Weights(kwargs["weights"], self)
kwargs["weights"] = weights_info.array

groupby_coords = []
dimension_to_groupby = None
Expand Down Expand Up @@ -4275,7 +4281,7 @@ def aggregated_by(
aggregateby_cube,
groupby_coords,
aggregate=True,
_weights_units=weights_units,
_weights_units=getattr(weights_info, "units", None),
**kwargs,
)
# Replace the appropriate coordinates within the aggregate-by cube.
Expand Down Expand Up @@ -4415,7 +4421,10 @@ def rolling_window(self, coord, aggregator, window, **kwargs):
"""
# Update weights kwargs (if necessary) to handle different types of
# weights
(kwargs, weights_units) = _Weights.get_updated_kwargs(kwargs, self)
weights_info = None
if kwargs.get("weights") is not None:
weights_info = _Weights(kwargs["weights"], self)
kwargs["weights"] = weights_info.array

coord = self._as_list_of_coords(coord)[0]

Expand Down Expand Up @@ -4502,7 +4511,7 @@ def rolling_window(self, coord, aggregator, window, **kwargs):
new_cube,
[coord],
action="with a rolling window of length %s over" % window,
_weights_units=weights_units,
_weights_units=getattr(weights_info, "units", None),
**kwargs,
)
# and perform the data transformation, generating weights first if
Expand Down
36 changes: 7 additions & 29 deletions lib/iris/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,35 +1814,13 @@ def test_init_with_cell_measure(self):
np.testing.assert_array_equal(weights.array, self.data)
assert weights.units == "m2"

def test_get_updated_kwargs_no_weights(self):
kwargs = {"test": [1, 2, 3]}
(new_kwargs, weights_units) = _Weights.get_updated_kwargs(
kwargs, self.cube
)
assert new_kwargs is not kwargs
assert new_kwargs == {"test": [1, 2, 3]}
assert weights_units is None

def test_get_updated_kwargs_weights_none(self):
kwargs = {"test": [1, 2, 3], "weights": None}
(new_kwargs, weights_units) = _Weights.get_updated_kwargs(
kwargs, self.cube
)
assert new_kwargs is not kwargs
assert new_kwargs == {"test": [1, 2, 3], "weights": None}
assert weights_units is None

def test_get_updated_kwargs_weights(self):
kwargs = {"test": [1, 2, 3], "weights": self.data}
(new_kwargs, weights_units) = _Weights.get_updated_kwargs(
kwargs, self.cube
)
assert new_kwargs is not kwargs
assert len(new_kwargs) == 2
assert new_kwargs["test"] == [1, 2, 3]
assert isinstance(new_kwargs["weights"], self.target_type)
assert new_kwargs["weights"] is self.data
assert weights_units == "1"
def test_init_with_list(self):
list_in = [0, 1, 2]
weights = _Weights(list_in, self.cube)
assert isinstance(weights.array, list)
assert isinstance(weights.units, cf_units.Unit)
assert weights.array is list_in
assert weights.units == "1"


class TestWeightsLazy(TestWeights):
Expand Down

0 comments on commit fb466ba

Please sign in to comment.