Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply to dataset #4863

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5655571
add a apply_to_dataset method
keewis Jan 30, 2021
90f8d55
write a test for apply_to_dataset on a DataArray
keewis Jan 30, 2021
fd2c897
also add a test for dataset
keewis Jan 30, 2021
857c783
convert apply_to_dataset to a top-level function
keewis Feb 5, 2021
57e94b6
update whats-new.rst
keewis Feb 5, 2021
cdb0f3d
add the new function to api.rst [skip-ci]
keewis Feb 5, 2021
1d81a49
rephrase the note [skip-ci]
keewis Feb 5, 2021
88fe863
add a see also section [skip-ci]
keewis Feb 5, 2021
0daf42d
add examples [skip-ci]
keewis Feb 5, 2021
ef3f791
Merge branch 'master' into apply-to-dataset
keewis Feb 7, 2021
638d61c
Merge branch 'master' into apply-to-dataset
keewis Feb 11, 2021
559d8ef
rename to call_on_dataset
keewis Mar 15, 2021
0c424bf
preserve the name as much as possible
keewis Mar 15, 2021
8db9e7e
update api.rst
keewis Mar 15, 2021
c902dfe
Merge branch 'master' into apply-to-dataset
keewis Mar 15, 2021
43bf70d
update whats-new.rst
keewis Mar 15, 2021
31645e5
remove the notes
keewis Mar 15, 2021
293d9c1
remove the no-op
keewis Mar 15, 2021
d0de1ca
don't rename to None
keewis Mar 15, 2021
a822232
rename to "<this-array>"
keewis Mar 15, 2021
d278919
rewrite [skip-ci]
keewis Mar 15, 2021
0669da9
Merge branch 'master' into apply-to-dataset
keewis Mar 15, 2021
97d4338
rename back to None
keewis Mar 15, 2021
48109db
Merge branch 'master' into apply-to-dataset
keewis Mar 28, 2021
b15d45e
Merge branch 'master' into apply-to-dataset
keewis Apr 5, 2021
371f509
introduce a mandatory name parameter to use as a name for the data va…
keewis May 10, 2021
8f37872
Merge branch 'master' into apply-to-dataset
keewis May 10, 2021
c9459f7
move to the new section in whats-new.rst
keewis May 10, 2021
021ad36
fix the tests
keewis May 11, 2021
dcb747b
Merge branch 'master' into apply-to-dataset
keewis May 31, 2021
7081e15
update the input and expected values
keewis May 31, 2021
52a39f3
add the missing name for the dataset call
keewis May 31, 2021
fcfaaa5
use DataArray.to_dataset instead
keewis May 31, 2021
f2d2880
only convert if the result is a Dataset
keewis May 31, 2021
b59dd1e
Merge branch 'master' into apply-to-dataset
dcherian Jun 21, 2021
12400cb
Merge branch 'main' into apply-to-dataset
keewis Jul 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Top-level functions
apply_ufunc
align
broadcast
call_on_dataset
concat
merge
combine_by_coords
Expand Down
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ v0.18.3 (unreleased)

New Features
~~~~~~~~~~~~

- Add :py:func:`call_on_dataset` as a way to apply functions expecting
:py:class:`Dataset` objects to :py:class:`DataArray` objects (:issue:`4837`, :pull:`4863`).
By `Justus Magin <https://github.com/keewis>`_.
- Xarray now uses consolidated metadata by default when writing and reading Zarr
stores (:issue:`5251`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
Expand Down
12 changes: 11 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,16 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
from .core.computation import (
apply_ufunc,
call_on_dataset,
corr,
cov,
dot,
polyval,
unify_chunks,
where,
)
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -46,6 +55,7 @@
# Top-level functions
"align",
"apply_ufunc",
"call_on_dataset",
"as_variable",
"broadcast",
"cftime_range",
Expand Down
88 changes: 88 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,94 @@ def apply_ufunc(
return apply_array_ufunc(func, *args, dask=dask)


def call_on_dataset(func, obj, name, *args, **kwargs):
"""apply a function expecting a Dataset to a xarray object

Parameters
----------
func : callable
A function expecting a Dataset as its first parameter.
obj : DataArray or Dataset
The dataset to apply ``func`` to. If a ``DataArray``, convert it to a single
variable ``Dataset`` first.
name : hashable
A intermediate name to use as the name of the data variable. If the DataArray
already had a name, it will be restored after converting back.
*args, **kwargs
Additional arguments to ``func``

Returns
-------
DataArray or Dataset
The result of ``func(obj, *args, **kwargs)`` with the same type as ``obj``.

See Also
--------
Dataset.map
Dataset.pipe
DataArray.pipe

Examples
--------
>>> def f(ds):
... return xr.Dataset(
... {
... name: var * var.attrs.get("scale", 1)
... for name, var in ds.data_vars.items()
... },
... coords=ds.coords,
... attrs=ds.attrs,
... )
...
>>> ds = xr.Dataset(
... {"a": ("x", [3, 4], {"scale": 0.5}), "b": ("x", [-1, 1], {"scale": 1.5})},
... coords={"x": [0, 1]},
... attrs={"attr": "value"},
... )
>>> ds
<xarray.Dataset>
Dimensions: (x: 2)
Coordinates:
* x (x) int64 0 1
Data variables:
a (x) int64 3 4
b (x) int64 -1 1
Attributes:
attr: value
>>> xr.call_on_dataset(f, ds, name="<this-array>")
<xarray.Dataset>
Dimensions: (x: 2)
Coordinates:
* x (x) int64 0 1
Data variables:
a (x) float64 1.5 2.0
b (x) float64 -1.5 1.5
Attributes:
attr: value
>>> xr.call_on_dataset(f, ds.a, name="<this-array>")
<xarray.DataArray 'a' (x: 2)>
array([1.5, 2. ])
Coordinates:
* x (x) int64 0 1
>>> xr.call_on_dataset(lambda ds: list(ds.variables.keys()), ds.a, name="data")
['x', 'data']
"""
from .dataarray import DataArray, Dataset
from .parallel import dataset_to_dataarray

if isinstance(obj, DataArray):
ds = obj.to_dataset(name=name)
else:
ds = obj

result = func(ds, *args, **kwargs)

if isinstance(obj, DataArray) and isinstance(result, Dataset):
result = dataset_to_dataarray(result).rename(obj.name)

return result


def cov(da_a, da_b, dim=None, ddof=1):
"""
Compute covariance between two DataArray objects along a shared dimension.
Expand Down
61 changes: 61 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,67 @@ def test_apply_groupby_add():
add(data_array.groupby("y"), data_array.groupby("x"))


@pytest.mark.parametrize(
["obj", "attrs", "expected"],
(
pytest.param(
xr.DataArray(
[0, 1],
dims="x",
),
{None: {"a": 1}},
xr.DataArray([0, 1], dims="x", attrs={"a": 1}),
id="unnamed DataArray",
),
pytest.param(
xr.DataArray(
[0, 1],
dims="x",
name="b",
),
{None: {"a": 1}},
xr.DataArray([0, 1], dims="x", attrs={"a": 1}, name="b"),
id="named DataArray",
),
pytest.param(
xr.Dataset(
{"a": ("x", [1, 2]), "b": ("x", [0, 1])},
coords={
"x": ("x", [-1, 1]),
"u": ("x", [2, 3]),
},
),
{"a": {"a": 1}, "u": {"b": 2}},
xr.Dataset(
{"a": ("x", [1, 2], {"a": 1}), "b": ("x", [0, 1])},
coords={"x": [-1, 1], "u": ("x", [2, 3], {"b": 2})},
),
id="Dataset",
),
),
)
def test_call_on_dataset(obj, attrs, expected):
temporary_name = "<this-array>"

def attach_attrs(ds, attrs):
new_ds = ds.copy()
for n, v in new_ds.variables.items():
if n == temporary_name:
n = None

if n not in attrs:
continue

v.attrs.update(attrs[n])

return new_ds

actual = xr.call_on_dataset(
lambda ds: attach_attrs(ds, attrs), obj, name=temporary_name
)
assert_identical(actual, expected)


def test_unified_dim_sizes():
assert unified_dim_sizes([xr.Variable((), 0)]) == {}
assert unified_dim_sizes([xr.Variable("x", [1]), xr.Variable("x", [1])]) == {"x": 1}
Expand Down