Block-mapped resample with the help of flox #1848

Merged
merged 42 commits on Oct 10, 2024

Changes from 30 commits

42 commits
30f28fe
Resample map helper
aulemahal Jul 19, 2024
88fa90b
merge generic-season
aulemahal Jul 19, 2024
065b4f3
New split_aux_coord func to remove aux coord and avoid dask comp on a…
aulemahal Jul 23, 2024
0c8c797
Ignore missing flox dep
aulemahal Jul 23, 2024
9c81ed8
Fix for bool mask
aulemahal Jul 23, 2024
764d67b
Fix aux coord mngmt in lazy indexing - fix doc split aux coord
aulemahal Jul 24, 2024
3d4c457
lower pin of flit
Jul 24, 2024
3adbf76
fix a fix that didnt fix what needed to be fixed
aulemahal Jul 26, 2024
befc53f
Merge branch 'resample-map' of github.com:Ouranosinc/xclim into resam…
aulemahal Jul 26, 2024
0bcbe5a
merge master and heat_spell
Aug 15, 2024
58fe1a0
Merge branch 'heat_spell' into resample-map
Aug 15, 2024
c39b47e
Merge branch 'heat_spell' into resample-map
Aug 15, 2024
9778739
Resample before spells
Aug 16, 2024
a7e5bde
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2024
b9d79b0
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
aulemahal Aug 16, 2024
b0fd634
Revert "Resample before spells"
aulemahal Aug 16, 2024
5a407aa
Merge branch 'heat_spell' into resample-map
Aug 18, 2024
167a387
merge main
Aug 28, 2024
38070c8
Merge branch 'main' into resample-map
Zeitsperre Aug 28, 2024
cef0e25
multi reducing
aulemahal Sep 5, 2024
df8c792
merge main
aulemahal Sep 6, 2024
90c14ef
fix deps - add minimal tests
aulemahal Sep 6, 2024
03d236e
add changelog
aulemahal Sep 6, 2024
1f3e82e
Dont test resample-map without flox
aulemahal Sep 6, 2024
b1dd2ac
Apply suggestions from code review
aulemahal Sep 6, 2024
8b717b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
881812b
Merge branch 'main' into resample-map
aulemahal Sep 27, 2024
561c54a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2024
54e5234
Skip auxiliary coords test
aulemahal Oct 1, 2024
ee2e352
add tests
aulemahal Oct 1, 2024
15fb7ed
Merge branch 'main' into resample-map
aulemahal Oct 3, 2024
6fcd4a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2024
9669035
Merge branch 'main' into resample-map
aulemahal Oct 3, 2024
50ffcff
Import callable
aulemahal Oct 4, 2024
d5d638a
fix test
aulemahal Oct 4, 2024
470f825
Merge branch 'main' into resample-map
aulemahal Oct 7, 2024
bb47d2b
Resample map for chill portions
aulemahal Oct 8, 2024
7817813
Merge branch 'main' into resample-map
aulemahal Oct 8, 2024
8096991
Merge branch 'main' into resample-map
Zeitsperre Oct 8, 2024
57da575
Fix docstring
aulemahal Oct 9, 2024
298386b
Merge branch 'main' into resample-map
aulemahal Oct 9, 2024
1a584d4
Merge branch 'main' into resample-map
aulemahal Oct 9, 2024
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -19,6 +19,7 @@ New features and enhancements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
* New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. (:pull:`1885`).
* Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`).
* New global option ``resample_map_blocks`` to wrap all ``resample().map()`` code inside ``xr.map_blocks``, lowering the number of dask tasks. Uses the utility ``xclim.indices.helpers.resample_map`` and requires ``flox`` to ensure the chunking allows such block-mapping. Defaults to ``False``. (:pull:`1848`).

Bug fixes
^^^^^^^^^
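A usage sketch for the new option (not part of the diff; assumes this PR and flox are installed, and uses hot_spell_frequency only as an illustrative index, mirroring the test added in tests/test_indices.py):

import numpy as np
import pandas as pd
import xarray as xr

from xclim.core.options import set_options
from xclim.indices import hot_spell_frequency

time = pd.date_range("2000-01-01", periods=3 * 365, freq="D")
tasmax = xr.DataArray(
    20 + 15 * np.random.rand(time.size),
    dims="time",
    coords={"time": time},
    attrs={"units": "degC"},
    name="tasmax",
).chunk({"time": 365})

with set_options(resample_map_blocks=True):
    # Inside the index, resample().map() goes through
    # xclim.indices.helpers.resample_map and xr.map_blocks.
    out = hot_spell_frequency(tasmax, resample_before_rl=True, freq="MS")
out.load()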
1 change: 1 addition & 0 deletions environment.yml
@@ -11,6 +11,7 @@ dependencies:
- click >=8.1
- dask >=2.6.0
- filelock >=3.14.0
- flox >=0.9
- jsonpickle >=3.1.0
- numba >=0.54.1
- numpy >=1.23.0
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -111,7 +111,7 @@ docs = [
"sphinxcontrib-bibtex",
"sphinxcontrib-svg2pdfconverter[Cairosvg]"
]
extras = ["fastnanquantile >=0.0.2", "POT >=0.9.4"]
extras = ["fastnanquantile >=0.0.2", "flox >=0.9", "POT >=0.9.4"]
all = ["xclim[dev]", "xclim[docs]", "xclim[extras]"]

[project.scripts]
45 changes: 45 additions & 0 deletions tests/test_helpers.py
@@ -5,8 +5,11 @@
import pytest
import xarray as xr

from xclim.core.options import set_options
from xclim.core.units import convert_units_to
from xclim.core.utils import uses_dask
from xclim.indices import helpers
from xclim.testing.helpers import assert_lazy


@pytest.mark.parametrize("method,rtol", [("spencer", 5e3), ("simple", 1e2)])
@@ -132,3 +135,45 @@ def test_cosine_of_solar_zenith_angle():
]
)
np.testing.assert_allclose(cza[:4, :], exp_cza, rtol=1e-3)


def _test_function(da, op, dim):
return getattr(da, op)(dim)


@pytest.mark.parametrize(
["in_chunks", "exp_chunks"], [(60, 6 * (2,)), (30, 12 * (1,)), (-1, (12,))]
)
def test_resample_map(tas_series, in_chunks, exp_chunks):
pytest.importorskip("flox")
tas = tas_series(365 * [1]).chunk(time=in_chunks)
with assert_lazy:
out = helpers.resample_map(
tas, "time", "MS", lambda da: da.mean("time"), map_blocks=True
)
assert out.chunks[0] == exp_chunks
out.load() # Trigger compute to see if it actually works


def test_resample_map_dataset(tas_series, pr_series):
pytest.importorskip("flox")
tas = tas_series(3 * 365 * [1], start="2000-01-01").chunk(time=365)
pr = pr_series(3 * 365 * [1], start="2000-01-01").chunk(time=365)
ds = xr.Dataset({"pr": pr, "tas": tas})
with set_options(resample_map_blocks=True):
with assert_lazy:
out = helpers.resample_map(
ds,
"time",
"YS",
lambda da: da.mean("time"),
)
assert out.chunks["time"] == (1, 1, 1)
out.load()


def test_resample_map_passthrough(tas_series):
tas = tas_series(365 * [1])
with assert_lazy:
out = helpers.resample_map(tas, "time", "MS", lambda da: da.mean("time"))
assert not uses_dask(out)
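For reference, the same helper used standalone outside the test fixtures (a sketch; requires dask and flox):

import numpy as np
import pandas as pd
import xarray as xr

from xclim.indices.helpers import resample_map

time = pd.date_range("2000-01-01", periods=365, freq="D")
tas = xr.DataArray(np.ones(365), dims="time", coords={"time": time}).chunk({"time": 60})

out = resample_map(tas, "time", "MS", lambda da: da.mean("time"), map_blocks=True)
# Input chunks are realigned to whole resampling periods, so 60-day chunks yield
# monthly outputs in chunks of two, as asserted in the parametrized test above.
print(out.chunks)
out.load()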
19 changes: 15 additions & 4 deletions tests/test_indices.py
@@ -1428,13 +1428,24 @@ def test_1d(self, tasmax_series, thresh, window, op, expected):
def test_resampling_order(self, tasmax_series, resample_before_rl, expected):
a = np.zeros(365)
a[5:35] = 31
tx = tasmax_series(a + K2C)
tx = tasmax_series(a + K2C).chunk()

hsf = xci.hot_spell_frequency(
tx, resample_before_rl=resample_before_rl, freq="MS"
)
).load()
assert hsf[1] == expected

@pytest.importorskip("flox")
@pytest.mark.parametrize("resample_map", [True, False])
def test_resampling_map(self, tasmax_series, resample_map):
a = np.zeros(365)
a[5:35] = 31
tx = tasmax_series(a + K2C).chunk()

with set_options(resample_map_blocks=resample_map):
hsf = xci.hot_spell_frequency(tx, resample_before_rl=True, freq="MS").load()
assert hsf[1] == 1


class TestHotSpellMaxLength:
@pytest.mark.parametrize(
@@ -1708,10 +1719,10 @@ def test_run_start_at_0(self, pr_series):
def test_resampling_order(self, pr_series, resample_before_rl, expected):
a = np.zeros(365) + 10
a[5:35] = 0
pr = pr_series(a)
pr = pr_series(a).chunk()
out = xci.maximum_consecutive_dry_days(
pr, freq="ME", resample_before_rl=resample_before_rl
)
).load()
assert out[0] == expected


14 changes: 7 additions & 7 deletions xclim/core/indicator.py
@@ -165,6 +165,7 @@
infer_kind_from_parameter,
is_percentile_dataarray,
load_module,
split_auxiliary_coordinates,
)

# Indicators registry
@@ -1446,13 +1447,12 @@ def _postprocess(self, outs, das, params):
# Reduce by or and broadcast to ensure the same length in time
# When indexing is used and there are no valid points in the last period, mask will not include it
mask = reduce(np.logical_or, miss)
if (
isinstance(mask, DataArray)
and "time" in mask.dims
and mask.time.size < outs[0].time.size
):
mask = mask.reindex(time=outs[0].time, fill_value=True)
outs = [out.where(np.logical_not(mask)) for out in outs]
if isinstance(mask, DataArray): # mask might be a bool in some cases
if "time" in mask.dims and mask.time.size < outs[0].time.size:
mask = mask.reindex(time=outs[0].time, fill_value=True)
# Remove any aux coord to avoid any unwanted dask computation in the alignment within "where"
mask, _ = split_auxiliary_coordinates(mask)
outs = [out.where(~mask) for out in outs]

return outs

6 changes: 6 additions & 0 deletions xclim/core/options.py
@@ -25,6 +25,7 @@
SDBA_ENCODE_CF = "sdba_encode_cf"
KEEP_ATTRS = "keep_attrs"
AS_DATASET = "as_dataset"
MAP_BLOCKS = "resample_map_blocks"

MISSING_METHODS: dict[str, Callable] = {}

@@ -39,6 +40,7 @@
SDBA_ENCODE_CF: False,
KEEP_ATTRS: "xarray",
AS_DATASET: False,
MAP_BLOCKS: False,
}

_LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"])
@@ -71,6 +73,7 @@ def _valid_missing_options(mopts):
SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool),
KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__,
AS_DATASET: lambda opt: isinstance(opt, bool),
MAP_BLOCKS: lambda opt: isinstance(opt, bool),
}


@@ -185,6 +188,9 @@ class set_options:
Note that xarray's "default" is equivalent to False. Default: ``"xarray"``.
as_dataset : bool
If True, indicators output datasets. If False, they output DataArrays. Default :``False``.
resample_map_blocks : bool
If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.helpers.resample_map`.
This requires `flox` to be installed in order to ensure the chunking is appropriate.

Examples
--------
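A short sketch of the option's mechanics, assuming set_options keeps its usual xclim/xarray semantics (applied immediately when called, restored when used as a context manager):

from xclim.core.options import MAP_BLOCKS, OPTIONS, set_options

set_options(resample_map_blocks=True)  # persistent, process-wide
assert OPTIONS[MAP_BLOCKS] is True

with set_options(resample_map_blocks=False):  # scoped override
    assert OPTIONS[MAP_BLOCKS] is False
assert OPTIONS[MAP_BLOCKS] is True  # restored on exit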
41 changes: 41 additions & 0 deletions xclim/core/utils.py
@@ -758,3 +758,44 @@ def _chunk_like(*inputs: xr.DataArray | xr.Dataset, chunks: dict[str, int] | Non
da.chunk(**{d: c for d, c in chunks.items() if d in da.dims})
)
return tuple(outputs)


def split_auxiliary_coordinates(
obj: xr.DataArray | xr.Dataset,
) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset]:
"""Split auxiliary coords from the dataset.

An auxiliary coordinate is a coordinate variable that does not define a dimension and thus is not necessarily needed for dataset alignment.
Any coordinate that has a name different from its dimension(s) is flagged as auxiliary. All scalar coordinates are flagged as auxiliary.

Parameters
----------
obj : DataArray or Dataset
Xarray object

Returns
-------
clean_obj : DataArray or Dataset
Same as `obj` but without any auxiliary coordinate.
aux_coords : Dataset
The auxiliary coordinates as a dataset. Might be empty.

Note
----
This is useful to circumvent xarray's alignment checks that will sometimes look at the auxiliary coordinate's data, which can trigger
unwanted dask computations.

The auxiliary coordinates can be merged back with the dataset with
:py:meth:`xarray.Dataset.assign_coords` or :py:meth:`xarray.DataArray.assign_coords`.

>>> # xdoctest: +SKIP
>>> clean, aux = split_auxiliary_coordinates(ds)
>>> merged = clean.assign_coords(aux.coords)
>>> merged.identical(ds) # True
"""
aux_crd_names = [
nm for nm, crd in obj.coords.items() if len(crd.dims) != 1 or crd.dims[0] != nm
]
aux_crd_ds = obj.coords.to_dataset()[aux_crd_names]
clean_obj = obj.drop_vars(aux_crd_names)
return clean_obj, aux_crd_ds
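An illustrative sketch of the new helper on a toy DataArray (names follow this diff):

import numpy as np
import xarray as xr

from xclim.core.utils import split_auxiliary_coordinates

da = xr.DataArray(
    np.arange(4.0),
    dims="x",
    coords={
        "x": [0, 1, 2, 3],  # dimension coordinate: kept
        "label": ("x", list("abcd")),  # name differs from its dimension: split off
        "height": 2.0,  # scalar coordinate: split off
    },
)

clean, aux = split_auxiliary_coordinates(da)
assert set(clean.coords) == {"x"}
assert {"label", "height"}.issubset(aux.coords)
# Re-attaching the auxiliary coordinates restores the original object.
assert clean.assign_coords(aux.coords).identical(da)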
36 changes: 22 additions & 14 deletions xclim/indices/_threshold.py
@@ -29,6 +29,7 @@
spell_length_statistics,
threshold_count,
)
from xclim.indices.helpers import resample_map

# Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
# See http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
@@ -1491,12 +1492,17 @@ def last_spring_frost(
thresh = convert_units_to(thresh, tasmin)
cond = compare(tasmin, op, thresh, constrain=("<", "<="))

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.last_run_before_date,
window=window,
date=before_date,
dim="time",
coord="dayofyear",
map_kwargs=dict(
window=window,
date=before_date,
dim="time",
coord="dayofyear",
),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(tasmin))
return out
@@ -1662,11 +1668,12 @@ def first_snowfall(
thresh = convert_units_to(thresh, prsn, context="hydro")
cond = prsn >= thresh

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.first_run,
window=1,
dim="time",
coord="dayofyear",
map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
return out
@@ -1717,11 +1724,12 @@ def last_snowfall(
thresh = convert_units_to(thresh, prsn, context="hydro")
cond = prsn >= thresh

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.last_run,
window=1,
dim="time",
coord="dayofyear",
map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
return out
@@ -3097,7 +3105,7 @@ def _exceedance_date(grp):
never_reached_val = never_reached
return xarray.where((cumsum <= sum_thresh).all("time"), never_reached_val, out)

dded = c.clip(0).resample(time=freq).map(_exceedance_date)
dded = resample_map(c.clip(0), "time", freq, _exceedance_date)
dded = dded.assign_attrs(
units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas)
)
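The conversion pattern applied throughout this file, shown on a toy series (a sketch; by default, without the ``resample_map_blocks`` option or ``map_blocks=True``, the helper simply passes through to ``resample().map()``):

import numpy as np
import pandas as pd
import xarray as xr

from xclim.indices import run_length as rl
from xclim.indices.helpers import resample_map

time = pd.date_range("2000-01-01", periods=365, freq="D")
cond = xr.DataArray(np.random.rand(365) > 0.5, dims="time", coords={"time": time})

# Before this PR:
before = cond.resample(time="MS").map(rl.first_run, window=1, dim="time", coord="dayofyear")

# After this PR (equivalent when block-mapping is not enabled):
after = resample_map(
    cond,
    "time",
    "MS",
    rl.first_run,
    map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
)
xr.testing.assert_equal(before, after)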