diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 559a05ba3f0..9369f79a5df 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2666,11 +2666,13 @@ def chunk(
             sizes along that dimension will not be updated; non-dask
             arrays will be converted into dask arrays with a single block.
 
+        Along datetime-like dimensions, a pandas frequency string (e.g. ``"YE"``) is also accepted.
+
         Parameters
         ----------
-        chunks : int, tuple of int, "auto" or mapping of hashable to int, optional
+        chunks : int, tuple of int, "auto", frequency string or mapping of hashable to int, optional
             Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
-            ``{"x": 5, "y": 5}``.
+            ``{"x": 5, "y": 5}`` or ``{"x": 5, "time": "YE"}``.
         name_prefix : str, default: "xarray-"
             Prefix for the name of any new dask arrays.
         token : str, optional
@@ -2705,6 +2707,8 @@ def chunk(
         xarray.unify_chunks
         dask.array.from_array
         """
+        from xarray.core.dataarray import DataArray
+
         if chunks is None and not chunks_kwargs:
             warnings.warn(
                 "None value for 'chunks' is deprecated. "
@@ -2730,6 +2734,41 @@ def chunk(
                 f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
             )
 
+        def _resolve_frequency(name: Hashable, freq: str) -> tuple[int, ...]:
+            variable = self._variables.get(name, None)
+            if variable is None:
+                raise ValueError(
+                    f"Cannot chunk by frequency string {freq!r} for virtual variables."
+                )
+            elif not _contains_datetime_like_objects(variable):
+                raise ValueError(
+                    f"chunks={freq!r} only supported for datetime variables. "
+                    f"Received variable {name!r} with dtype {variable.dtype!r} instead."
+                )
+
+            chunks = tuple(
+                DataArray(
+                    np.ones(variable.shape, dtype=int),
+                    dims=(name,),
+                    coords={name: variable},
+                )
+                # TODO: This could be generalized to `freq` being a `Resampler` object,
+                # and using `groupby` instead of `resample`
+                .resample({name: freq})
+                .sum()
+                .data.tolist()
+            )
+            return chunks
+
+        chunks_mapping_ints = {
+            name: (
+                _resolve_frequency(name, chunks)
+                if isinstance(chunks, str) and chunks != "auto"
+                else chunks
+            )
+            for name, chunks in chunks_mapping.items()
+        }
+
         chunkmanager = guess_chunkmanager(chunked_array_type)
         if from_array_kwargs is None:
             from_array_kwargs = {}
@@ -2738,7 +2777,7 @@ def chunk(
             k: _maybe_chunk(
                 k,
                 v,
-                chunks_mapping,
+                chunks_mapping_ints,
                 token,
                 lock,
                 name_prefix,
diff --git a/xarray/core/types.py b/xarray/core/types.py
index 588d15dc35d..1e4b7102edd 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -190,6 +190,7 @@ def copy(
 # FYI in some cases we don't allow `None`, which this doesn't take account of.
 # FYI the `str` is for a size string, e.g. "16MB", supported by dask.
-T_ChunkDim: TypeAlias = Union[str, int, Literal["auto"], None, tuple[int, ...]]
+T_FreqStr: TypeAlias = str
+T_ChunkDim: TypeAlias = Union[T_FreqStr, int, Literal["auto"], None, tuple[int, ...]]
 # We allow the tuple form of this (though arguably we could transition to named dims only)
 T_Chunks: TypeAlias = Union[T_ChunkDim, Mapping[Any, T_ChunkDim]]
 T_NormalizedChunks = tuple[tuple[int, ...], ...]
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index f6829861776..19ad9a84b4b 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -1192,6 +1192,45 @@ def get_dask_names(ds):
         ):
             data.chunk({"foo": 10})
 
+    @requires_dask
+    @pytest.mark.parametrize(
+        "calendar",
+        (
+            "standard",
+            pytest.param(
+                "gregorian",
+                marks=pytest.mark.skipif(not has_cftime, reason="needs cftime"),
+            ),
+        ),
+    )
+    @pytest.mark.parametrize("freq", ["D", "W", "5ME", "YE"])
+    def test_chunk_by_frequency(self, freq, calendar) -> None:
+        import dask.array
+
+        N = 365 * 2
+        ds = Dataset(
+            {
+                "pr": ("time", dask.array.random.random((N), chunks=(20))),
+                "ones": ("time", np.ones((N,))),
+            },
+            coords={
+                "time": xr.date_range(
+                    "2001-01-01", periods=N, freq="D", calendar=calendar
+                )
+            },
+        )
+        actual = ds.chunk(time=freq).chunksizes["time"]
+        expected = tuple(ds.ones.resample(time=freq).sum().data.tolist())
+        assert expected == actual
+
+    def test_chunk_by_frequency_errors(self) -> None:
+        ds = Dataset({"foo": ("x", [1, 2, 3])})
+        with pytest.raises(ValueError, match="virtual variable"):
+            ds.chunk(x="YE")
+        ds["x"] = ("x", [1, 2, 3])
+        with pytest.raises(ValueError, match="datetime variables"):
+            ds.chunk(x="YE")
+
     @requires_dask
     def test_dask_is_lazy(self) -> None:
         store = InaccessibleVariableDataStore()
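For reference, a minimal usage sketch of the behaviour added above, assuming this patch is applied and dask is installed. The variable name ``pr`` and the two-year daily time axis mirror the test and are purely illustrative:

import numpy as np
import xarray as xr

# Two years of daily data on a standard-calendar time axis.
time = xr.date_range("2001-01-01", periods=730, freq="D")
ds = xr.Dataset({"pr": ("time", np.ones(730))}, coords={"time": time})

# Passing a frequency string along a datetime dimension yields one dask chunk
# per resample bin, here one chunk per calendar year.
chunked = ds.chunk(time="YE")
print(chunked.chunksizes["time"])  # (365, 365)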