diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 62da8bf1ea2..d07a2741eef 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v0.18.1 (unreleased) New Features ~~~~~~~~~~~~ +- allow passing ``combine_attrs`` strategy names to the ``keep_attrs`` parameter of + :py:func:`apply_ufunc` (:pull:`5041`) + By `Justus Magin `_. - :py:meth:`Dataset.interp` now allows interpolation with non-numerical datatypes, such as booleans, instead of dropping them. (:issue:`4761` :pull:`5008`). By `Jimmy Westling `_. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2bc3ba921a7..12dded3e158 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -27,8 +27,8 @@ from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align -from .merge import merge_coordinates_without_align -from .options import OPTIONS +from .merge import merge_attrs, merge_coordinates_without_align +from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array from .utils import is_dict_like from .variable import Variable @@ -50,6 +50,11 @@ def _first_of_type(args, kind): raise ValueError("This should be unreachable.") +def _all_of_type(args, kind): + """Return all objects of type 'kind'""" + return [arg for arg in args if isinstance(arg, kind)] + + class _UFuncSignature: """Core dimensions signature for a given function. @@ -202,7 +207,10 @@ def _get_coords_list(args) -> List["Coordinates"]: def build_output_coords( - args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset() + args: list, + signature: _UFuncSignature, + exclude_dims: AbstractSet = frozenset(), + combine_attrs: str = "override", ) -> "List[Dict[Any, Variable]]": """Build output coordinates for an operation. @@ -230,7 +238,7 @@ def build_output_coords( else: # TODO: save these merged indexes, instead of re-computing them later merged_vars, unused_indexes = merge_coordinates_without_align( - coords_list, exclude_dims=exclude_dims + coords_list, exclude_dims=exclude_dims, combine_attrs=combine_attrs ) output_coords = [] @@ -248,7 +256,12 @@ def build_output_coords( def apply_dataarray_vfunc( - func, *args, signature, join="inner", exclude_dims=frozenset(), keep_attrs=False + func, + *args, + signature, + join="inner", + exclude_dims=frozenset(), + keep_attrs="override", ): """Apply a variable level function over DataArray, Variable and/or ndarray objects. @@ -260,12 +273,16 @@ def apply_dataarray_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - if keep_attrs: + objs = _all_of_type(args, DataArray) + + if keep_attrs == "drop": + name = result_name(args) + else: first_obj = _first_of_type(args, DataArray) name = first_obj.name - else: - name = result_name(args) - result_coords = build_output_coords(args, signature, exclude_dims) + result_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) data_vars = [getattr(a, "variable", a) for a in args] result_var = func(*data_vars) @@ -279,13 +296,12 @@ def apply_dataarray_vfunc( (coords,) = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) - if keep_attrs: - if isinstance(out, tuple): - for da in out: - # This is adding attrs in place - da._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for da in out: + da.attrs = attrs + else: + out.attrs = attrs return out @@ -400,7 +416,7 @@ def apply_dataset_vfunc( dataset_join="exact", fill_value=_NO_FILL_VALUE, exclude_dims=frozenset(), - keep_attrs=False, + keep_attrs="override", ): """Apply a variable level function over Dataset, dict of DataArray, DataArray, Variable and/or ndarray objects. @@ -414,15 +430,16 @@ def apply_dataset_vfunc( "dataset_fill_value argument." ) - if keep_attrs: - first_obj = _first_of_type(args, Dataset) + objs = _all_of_type(args, Dataset) if len(args) > 1: args = deep_align( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - list_of_coords = build_output_coords(args, signature, exclude_dims) + list_of_coords = build_output_coords( + args, signature, exclude_dims, combine_attrs=keep_attrs + ) args = [getattr(arg, "data_vars", arg) for arg in args] result_vars = apply_dict_of_variables_vfunc( @@ -435,13 +452,13 @@ def apply_dataset_vfunc( (coord_vars,) = list_of_coords out = _fast_dataset(result_vars, coord_vars) - if keep_attrs: - if isinstance(out, tuple): - for ds in out: - # This is adding attrs in place - ds._copy_attrs_from(first_obj) - else: - out._copy_attrs_from(first_obj) + attrs = merge_attrs([x.attrs for x in objs], combine_attrs=keep_attrs) + if isinstance(out, tuple): + for ds in out: + ds.attrs = attrs + else: + out.attrs = attrs + return out @@ -609,14 +626,12 @@ def apply_variable_ufunc( dask="forbidden", output_dtypes=None, vectorize=False, - keep_attrs=False, + keep_attrs="override", dask_gufunc_kwargs=None, ): """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data - first_obj = _first_of_type(args, Variable) - dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims ) @@ -736,6 +751,12 @@ def func(*arrays): ) ) + objs = _all_of_type(args, Variable) + attrs = merge_attrs( + [obj.attrs for obj in objs], + combine_attrs=keep_attrs, + ) + output = [] for dims, data in zip(output_dims, result_data): data = as_compatible_data(data) @@ -758,8 +779,7 @@ def func(*arrays): ) ) - if keep_attrs: - var.attrs.update(first_obj.attrs) + var.attrs = attrs output.append(var) if signature.num_outputs == 1: @@ -801,7 +821,7 @@ def apply_ufunc( join: str = "exact", dataset_join: str = "exact", dataset_fill_value: object = _NO_FILL_VALUE, - keep_attrs: bool = False, + keep_attrs: Union[bool, str] = None, kwargs: Mapping = None, dask: str = "forbidden", output_dtypes: Sequence = None, @@ -1098,6 +1118,12 @@ def apply_ufunc( if kwargs: func = functools.partial(func, **kwargs) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + if isinstance(keep_attrs, bool): + keep_attrs = "override" if keep_attrs else "drop" + variables_vfunc = functools.partial( apply_variable_ufunc, func, diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 06da058eb1f..4901c5bf312 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -314,6 +314,7 @@ def merge_coordinates_without_align( objects: "List[Coordinates]", prioritized: Mapping[Hashable, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), + combine_attrs: str = "override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. @@ -335,7 +336,7 @@ def merge_coordinates_without_align( else: filtered = collected - return merge_collected(filtered, prioritized) + return merge_collected(filtered, prioritized, combine_attrs=combine_attrs) def determine_coords( diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index c76633de831..cbfa61a4482 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -29,6 +29,7 @@ def assert_identical(a, b): """A version of this function which accepts numpy arrays""" + __tracebackhide__ = True from xarray.testing import assert_identical as assert_identical_ if hasattr(a, "identical"): @@ -563,6 +564,401 @@ def add(a, b, keep_attrs): assert_identical(actual.x.attrs, a.x.attrs) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_variable(strategy, attrs, expected, error): + a = xr.Variable("x", [0, 1], attrs=attrs[0]) + b = xr.Variable("x", [0, 1], attrs=attrs[1]) + c = xr.Variable("x", [0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Variable("x", [0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray(strategy, attrs, expected, error): + a = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[0]) + b = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[1]) + c = xr.DataArray(dims="x", data=[0, 1], attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.DataArray(dims="x", data=[0, 3], attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataarray_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "dim": lambda attrs, default: (attrs, default), + "coord": lambda attrs, default: (default, attrs), + }.get(variant) + + dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[0]), "u": ("x", [0, 1], coord_attrs[0])}, + ) + b = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[1]), "u": ("x", [0, 1], coord_attrs[1])}, + ) + c = xr.DataArray( + dims="x", + data=[0, 1], + coords={"x": ("x", [0, 1], dim_attrs[2]), "u": ("x", [0, 1], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.DataArray( + dims="x", + data=[0, 3], + coords={"x": ("x", [0, 1], dim_attrs), "u": ("x", [0, 1], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset(strategy, attrs, expected, error): + a = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[0]) + b = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[1]) + c = xr.Dataset({"a": ("x", [0, 1])}, attrs=attrs[2]) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + expected = xr.Dataset({"a": ("x", [0, 3])}, attrs=expected) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + +@pytest.mark.parametrize("variant", ("data", "dim", "coord")) +@pytest.mark.parametrize( + ["strategy", "attrs", "expected", "error"], + ( + pytest.param( + None, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="default", + ), + pytest.param( + False, + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="False", + ), + pytest.param( + True, + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="True", + ), + pytest.param( + "override", + [{"a": 1}, {"a": 2}, {"a": 3}], + {"a": 1}, + False, + id="override", + ), + pytest.param( + "drop", + [{"a": 1}, {"a": 2}, {"a": 3}], + {}, + False, + id="drop", + ), + pytest.param( + "drop_conflicts", + [{"a": 1, "b": 2}, {"b": 1, "c": 3}, {"c": 3, "d": 4}], + {"a": 1, "c": 3, "d": 4}, + False, + id="drop_conflicts", + ), + pytest.param( + "no_conflicts", + [{"a": 1}, {"b": 2}, {"b": 3}], + None, + True, + id="no_conflicts", + ), + ), +) +def test_keep_attrs_strategies_dataset_variables( + variant, strategy, attrs, expected, error +): + compute_attrs = { + "data": lambda attrs, default: (attrs, default, default), + "dim": lambda attrs, default: (default, attrs, default), + "coord": lambda attrs, default: (default, default, attrs), + }.get(variant) + data_attrs, dim_attrs, coord_attrs = compute_attrs(attrs, [{}, {}, {}]) + + a = xr.Dataset( + {"a": ("x", [], data_attrs[0])}, + coords={"x": ("x", [], dim_attrs[0]), "u": ("x", [], coord_attrs[0])}, + ) + b = xr.Dataset( + {"a": ("x", [], data_attrs[1])}, + coords={"x": ("x", [], dim_attrs[1]), "u": ("x", [], coord_attrs[1])}, + ) + c = xr.Dataset( + {"a": ("x", [], data_attrs[2])}, + coords={"x": ("x", [], dim_attrs[2]), "u": ("x", [], coord_attrs[2])}, + ) + + if error: + with pytest.raises(xr.MergeError): + apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + else: + data_attrs, dim_attrs, coord_attrs = compute_attrs(expected, {}) + expected = xr.Dataset( + {"a": ("x", [], data_attrs)}, + coords={"x": ("x", [], dim_attrs), "u": ("x", [], coord_attrs)}, + ) + actual = apply_ufunc(lambda *args: sum(args), a, b, c, keep_attrs=strategy) + + assert_identical(actual, expected) + + def test_dataset_join(): ds0 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]}) ds1 = xr.Dataset({"a": ("x", [99, 3]), "x": [1, 2]})