diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 10f6b23ca66..c082bcaa263 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -24,6 +24,10 @@ Breaking changes
 
 New Features
 ~~~~~~~~~~~~
 
+- Control over the attributes of the result in :py:func:`merge`,
+  :py:func:`concat`, :py:func:`combine_by_coords` and :py:func:`combine_nested`
+  via the ``combine_attrs`` keyword argument. (:issue:`3865`, :pull:`3877`)
+  By `John Omotani `_
 
 Bug fixes
 
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 1fa2df00352..1f990457798 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -155,6 +155,7 @@ def _combine_nd(
     compat="no_conflicts",
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="drop",
 ):
     """
     Combines an N-dimensional structure of datasets into one by applying a
@@ -202,13 +203,21 @@ def _combine_nd(
             compat=compat,
             fill_value=fill_value,
             join=join,
+            combine_attrs=combine_attrs,
         )
     (combined_ds,) = combined_ids.values()
     return combined_ds
 
 
 def _combine_all_along_first_dim(
-    combined_ids, dim, data_vars, coords, compat, fill_value=dtypes.NA, join="outer"
+    combined_ids,
+    dim,
+    data_vars,
+    coords,
+    compat,
+    fill_value=dtypes.NA,
+    join="outer",
+    combine_attrs="drop",
 ):
     # Group into lines of datasets which must be combined along dim
@@ -223,7 +232,7 @@ def _combine_all_along_first_dim(
         combined_ids = dict(sorted(group))
         datasets = combined_ids.values()
         new_combined_ids[new_id] = _combine_1d(
-            datasets, dim, compat, data_vars, coords, fill_value, join
+            datasets, dim, compat, data_vars, coords, fill_value, join, combine_attrs
         )
     return new_combined_ids
 
@@ -236,6 +245,7 @@ def _combine_1d(
     coords="different",
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="drop",
 ):
     """
     Applies either concat or merge to 1D list of datasets depending on value
@@ -252,6 +262,7 @@ def _combine_1d(
                 compat=compat,
                 fill_value=fill_value,
                 join=join,
+                combine_attrs=combine_attrs,
             )
         except ValueError as err:
             if "encountered unexpected variable" in str(err):
@@ -265,7 +276,13 @@ def _combine_1d(
             else:
                 raise
     else:
-        combined = merge(datasets, compat=compat, fill_value=fill_value, join=join)
+        combined = merge(
+            datasets,
+            compat=compat,
+            fill_value=fill_value,
+            join=join,
+            combine_attrs=combine_attrs,
+        )
 
     return combined
 
@@ -284,6 +301,7 @@ def _nested_combine(
     ids,
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="drop",
 ):
 
     if len(datasets) == 0:
@@ -311,6 +329,7 @@ def _nested_combine(
         coords=coords,
         fill_value=fill_value,
         join=join,
+        combine_attrs=combine_attrs,
     )
     return combined
 
 
@@ -323,6 +342,7 @@ def combine_nested(
     coords="different",
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="drop",
 ):
     """
     Explicitly combine an N-dimensional grid of datasets into one by using a
@@ -390,6 +410,16 @@ def combine_nested(
       - 'override': if indexes are of same size, rewrite indexes to be
         those of the first object with that dimension. Indexes for the same
         dimension must have the same size in all objects.
+    combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
+                    default 'drop'
+        String indicating how to combine attrs of the objects being merged:
+
+        - 'drop': empty attrs on returned Dataset.
+        - 'identical': all attrs must be the same on every object.
+        - 'no_conflicts': attrs from all objects are combined, any that have
+          the same name must also have the same value.
+        - 'override': skip comparing and copy attrs from the first dataset to
+          the result.
 
     Returns
     -------
@@ -468,6 +498,7 @@ def combine_nested(
         ids=False,
         fill_value=fill_value,
         join=join,
+        combine_attrs=combine_attrs,
     )
 
 
@@ -482,6 +513,7 @@ def combine_by_coords(
     coords="different",
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="no_conflicts",
 ):
     """
     Attempt to auto-magically combine the given datasets into one by using
@@ -557,6 +589,16 @@ def combine_by_coords(
       - 'override': if indexes are of same size, rewrite indexes to be
         those of the first object with that dimension. Indexes for the same
         dimension must have the same size in all objects.
+    combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
+                    default 'no_conflicts'
+        String indicating how to combine attrs of the objects being merged:
+
+        - 'drop': empty attrs on returned Dataset.
+        - 'identical': all attrs must be the same on every object.
+        - 'no_conflicts': attrs from all objects are combined, any that have
+          the same name must also have the same value.
+        - 'override': skip comparing and copy attrs from the first dataset to
+          the result.
 
     Returns
     -------
@@ -700,6 +742,7 @@ def combine_by_coords(
                 compat=compat,
                 fill_value=fill_value,
                 join=join,
+                combine_attrs=combine_attrs,
             )
 
             # Check the overall coordinates are monotonically increasing
@@ -717,6 +760,7 @@ def combine_by_coords(
         compat=compat,
         fill_value=fill_value,
         join=join,
+        combine_attrs=combine_attrs,
     )
 
 
diff --git a/xarray/core/concat.py b/xarray/core/concat.py
index 96b4be15d1b..7741cbb826b 100644
--- a/xarray/core/concat.py
+++ b/xarray/core/concat.py
@@ -3,7 +3,7 @@
 from . import dtypes, utils
 from .alignment import align
 from .duck_array_ops import lazy_array_equiv
-from .merge import _VALID_COMPAT, unique_variable
+from .merge import _VALID_COMPAT, merge_attrs, unique_variable
 from .variable import IndexVariable, Variable, as_variable
 from .variable import concat as concat_vars
 
@@ -17,6 +17,7 @@ def concat(
     positions=None,
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="override",
 ):
     """Concatenate xarray objects along a new or existing dimension.
 
@@ -92,15 +93,21 @@ def concat(
       - 'override': if indexes are of same size, rewrite indexes to be
         those of the first object with that dimension. Indexes for the same
         dimension must have the same size in all objects.
+    combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
+                    default 'override'
+        String indicating how to combine attrs of the objects being merged:
+
+        - 'drop': empty attrs on returned Dataset.
+        - 'identical': all attrs must be the same on every object.
+        - 'no_conflicts': attrs from all objects are combined, any that have
+          the same name must also have the same value.
+        - 'override': skip comparing and copy attrs from the first dataset to
+          the result.
 
     Returns
     -------
     concatenated : type of objs
 
-    Notes
-    -----
-    Each concatenated Variable preserves corresponding ``attrs`` from the first
-    element of ``objs``.
-
     See also
     --------
     merge
@@ -132,7 +139,9 @@ def concat(
             "can only concatenate xarray Dataset and DataArray "
             "objects, got %s" % type(first_obj)
         )
-    return f(objs, dim, data_vars, coords, compat, positions, fill_value, join)
+    return f(
+        objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs
+    )
 
 
 def _calc_concat_dim_coord(dim):
@@ -306,6 +315,7 @@ def _dataset_concat(
     positions,
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="override",
 ):
     """
     Concatenate a sequence of datasets along a new or existing dimension
@@ -362,7 +372,7 @@ def _dataset_concat(
     result_vars.update(dim_coords)
 
-    # assign attrs and encoding from first dataset
-    result_attrs = datasets[0].attrs
+    # combine attrs as requested; encoding still comes from the first dataset
+    result_attrs = merge_attrs([ds.attrs for ds in datasets], combine_attrs)
     result_encoding = datasets[0].encoding
 
     # check that global attributes are fixed across all datasets if necessary
@@ -425,6 +435,7 @@ def _dataarray_concat(
     positions,
     fill_value=dtypes.NA,
     join="outer",
+    combine_attrs="override",
 ):
     arrays = list(arrays)
 
@@ -453,5 +464,12 @@ def _dataarray_concat(
         positions,
         fill_value=fill_value,
         join=join,
+        combine_attrs="drop",
     )
-    return arrays[0]._from_temp_dataset(ds, name)
+
+    merged_attrs = merge_attrs([da.attrs for da in arrays], combine_attrs)
+
+    result = arrays[0]._from_temp_dataset(ds, name)
+    result.attrs = merged_attrs
+
+    return result
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 324e7ccd290..232fb86144e 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -475,7 +475,13 @@ def _to_dataset_whole(
         dataset = Dataset._construct_direct(variables, coord_names, indexes=indexes)
         return dataset
 
-    def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
+    def to_dataset(
+        self,
+        dim: Hashable = None,
+        *,
+        name: Hashable = None,
+        promote_attrs: bool = False,
+    ) -> Dataset:
         """Convert a DataArray to a Dataset.
 
         Parameters
@@ -487,6 +493,8 @@ def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
         name : hashable, optional
             Name to substitute for this array's name. Only valid if ``dim`` is
             not provided.
+        promote_attrs : bool, default False
+            If True, shallow-copy the attrs of this DataArray onto the returned Dataset.
 
         Returns
         -------
@@ -500,9 +508,14 @@ def to_dataset(self, dim: Hashable = None, *, name: Hashable = None) -> Dataset:
         if dim is not None:
             if name is not None:
                 raise TypeError("cannot supply both dim and name arguments")
-            return self._to_dataset_split(dim)
+            result = self._to_dataset_split(dim)
         else:
-            return self._to_dataset_whole(name)
+            result = self._to_dataset_whole(name)
+
+        if promote_attrs:
+            result.attrs = dict(self.attrs)
+
+        return result
 
     @property
     def name(self) -> Optional[Hashable]:
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index b7ce0ec4e1e..6f96e4f469c 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -536,7 +536,7 @@ def __init__(
         if isinstance(coords, Dataset):
             coords = coords.variables
 
-        variables, coord_names, dims, indexes = merge_data_and_coords(
+        variables, coord_names, dims, indexes, _ = merge_data_and_coords(
             data_vars, coords, compat="broadcast_equals"
         )
 
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index 1d1b8d39a20..fea94246471 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -20,7 +20,7 @@
 from . import dtypes, pdcompat
 from .alignment import deep_align
 from .duck_array_ops import lazy_array_equiv
-from .utils import Frozen, dict_equiv
+from .utils import Frozen, compat_dict_union, dict_equiv
 from .variable import Variable, as_variable, assert_unique_multiindex_level_names
 
 if TYPE_CHECKING:
@@ -491,17 +491,54 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords):
         )
 
 
+def merge_attrs(variable_attrs, combine_attrs):
+    """Combine attributes from different variables according to combine_attrs.
+    """
+    if not variable_attrs:
+        # no attributes to merge
+        return None
+
+    if combine_attrs == "drop":
+        return {}
+    elif combine_attrs == "override":
+        return variable_attrs[0]
+    elif combine_attrs == "no_conflicts":
+        result = dict(variable_attrs[0])
+        for attrs in variable_attrs[1:]:
+            try:
+                result = compat_dict_union(result, attrs)
+            except ValueError:
+                raise MergeError(
+                    "combine_attrs='no_conflicts', but some values are not "
+                    "the same. Merging %s with %s" % (str(result), str(attrs))
+                )
+        return result
+    elif combine_attrs == "identical":
+        result = dict(variable_attrs[0])
+        for attrs in variable_attrs[1:]:
+            if not dict_equiv(result, attrs):
+                raise MergeError(
+                    "combine_attrs='identical', but attrs differ. First is %s, "
+                    "other is %s." % (str(result), str(attrs))
+                )
+        return result
+    else:
+        raise ValueError("Unrecognised value for combine_attrs=%s" % combine_attrs)
+
+
 class _MergeResult(NamedTuple):
     variables: Dict[Hashable, Variable]
     coord_names: Set[Hashable]
     dims: Dict[Hashable, int]
     indexes: Dict[Hashable, pd.Index]
+    attrs: Dict[Hashable, Any]
 
 
 def merge_core(
     objects: Iterable["CoercibleMapping"],
     compat: str = "broadcast_equals",
     join: str = "outer",
+    combine_attrs: str = "override",
     priority_arg: Optional[int] = None,
     explicit_coords: Optional[Sequence] = None,
     indexes: Optional[Mapping[Hashable, pd.Index]] = None,
@@ -519,6 +556,8 @@ def merge_core(
         Compatibility checks to use when merging variables.
     join : {'outer', 'inner', 'left', 'right'}, optional
         How to combine objects with different indexes.
+    combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, optional
+        How to combine attributes of the objects being merged.
     priority_arg : integer, optional
         Optional argument in `objects` that takes precedence over the others.
     explicit_coords : set, optional
@@ -536,12 +575,15 @@ def merge_core(
         Set of coordinate names.
     dims : dict
         Dictionary mapping from dimension names to sizes.
+    attrs : dict
+        Dictionary of attributes.
 
     Raises
    ------
     MergeError if the merge cannot be done successfully.
     """
-    from .dataset import calculate_dimensions
+    from .dataarray import DataArray
+    from .dataset import Dataset, calculate_dimensions
 
     _assert_compat_valid(compat)
 
@@ -571,7 +613,16 @@ def merge_core(
             "coordinates or not in the merged result: %s" % ambiguous_coords
         )
 
-    return _MergeResult(variables, coord_names, dims, out_indexes)
+    attrs = merge_attrs(
+        [
+            var.attrs
+            for var in coerced
+            if isinstance(var, (Dataset, DataArray))
+        ],
+        combine_attrs,
+    )
+
+    return _MergeResult(variables, coord_names, dims, out_indexes, attrs)
 
 
 def merge(
@@ -579,6 +630,7 @@ def merge(
     compat: str = "no_conflicts",
     join: str = "outer",
     fill_value: object = dtypes.NA,
+    combine_attrs: str = "drop",
 ) -> "Dataset":
     """Merge any number of xarray objects into a single Dataset as variables.
 
@@ -614,6 +666,16 @@ def merge(
         dimension must have the same size in all objects.
     fill_value : scalar, optional
         Value to use for newly missing values
+    combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'},
+                    default 'drop'
+        String indicating how to combine attrs of the objects being merged:
+
+        - 'drop': empty attrs on returned Dataset.
+        - 'identical': all attrs must be the same on every object.
+        - 'no_conflicts': attrs from all objects are combined, any that have
+          the same name must also have the same value.
+        - 'override': skip comparing and copy attrs from the first dataset to
+          the result.
 
     Returns
     -------
@@ -787,10 +849,16 @@ def merge(
                 "Dataset(s), DataArray(s), and dictionaries."
             )
 
-        obj = obj.to_dataset() if isinstance(obj, DataArray) else obj
+        obj = obj.to_dataset(promote_attrs=True) if isinstance(obj, DataArray) else obj
         dict_like_objects.append(obj)
 
-    merge_result = merge_core(dict_like_objects, compat, join, fill_value=fill_value)
+    merge_result = merge_core(
+        dict_like_objects,
+        compat,
+        join,
+        combine_attrs=combine_attrs,
+        fill_value=fill_value,
+    )
     merged = Dataset._construct_direct(**merge_result._asdict())
     return merged
 
@@ -861,4 +929,9 @@ def dataset_update_method(
         if coord_names:
             other[key] = value.drop_vars(coord_names)
 
-    return merge_core([dataset, other], priority_arg=1, indexes=dataset.indexes)
+    return merge_core(
+        [dataset, other],
+        priority_arg=1,
+        indexes=dataset.indexes,
+        combine_attrs="override",
+    )
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index e335365d5ca..5570f9e9a80 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -184,7 +184,7 @@ def peek_at(iterable: Iterable[T]) -> Tuple[T, Iterator[T]]:
 
 
 def update_safety_check(
-    first_dict: MutableMapping[K, V],
+    first_dict: Mapping[K, V],
     second_dict: Mapping[K, V],
     compat: Callable[[V, V], bool] = equivalent,
 ) -> None:
@@ -361,6 +361,9 @@ def ordered_dict_intersection(
         Binary operator to determine if two values are compatible. By default,
         checks for equivalence.
 
+    TODO: rename to compat_dict_intersection, since OrderedDict is no longer
+    used and plain dicts preserve insertion order.
+
     Returns
     -------
     intersection : dict
@@ -371,6 +374,35 @@ def ordered_dict_intersection(
     return new_dict
 
 
+def compat_dict_union(
+    first_dict: Mapping[K, V],
+    second_dict: Mapping[K, V],
+    compat: Callable[[V, V], bool] = equivalent,
+) -> MutableMapping[K, V]:
+    """Return the union of two dictionaries as a new dictionary.
+
+    An exception is raised if any keys are found in both dictionaries and the
+    values are not compatible.
+
+    Parameters
+    ----------
+    first_dict, second_dict : dict-like
+        Mappings to merge.
+    compat : function, optional
+        Binary operator to determine if two values are compatible. By default,
+        checks for equivalence.
+
+    Returns
+    -------
+    union : dict
+        Union of the contents.
+    """
+    new_dict = dict(first_dict)
+    update_safety_check(first_dict, second_dict, compat)
+    new_dict.update(second_dict)
+    return new_dict
+
+
 class Frozen(Mapping[K, V]):
     """Wrapper around an object implementing the mapping interface to make it
     immutable.
If you really want to modify the mapping, the mutable version is diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index eb2c6e1dbf7..c3f981f10d1 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -503,6 +503,49 @@ def test_auto_combine_2d(self): result = combine_nested(datasets, concat_dim=["dim1", "dim2"]) assert_equal(result, expected) + def test_auto_combine_2d_combine_attrs_kwarg(self): + ds = create_test_data + + partway1 = concat([ds(0), ds(3)], dim="dim1") + partway2 = concat([ds(1), ds(4)], dim="dim1") + partway3 = concat([ds(2), ds(5)], dim="dim1") + expected = concat([partway1, partway2, partway3], dim="dim2") + + expected_dict = {} + expected_dict["drop"] = expected.copy(deep=True) + expected_dict["drop"].attrs = {} + expected_dict["no_conflicts"] = expected.copy(deep=True) + expected_dict["no_conflicts"].attrs = { + "a": 1, + "b": 2, + "c": 3, + "d": 4, + "e": 5, + "f": 6, + } + expected_dict["override"] = expected.copy(deep=True) + expected_dict["override"].attrs = {"a": 1} + + datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4), ds(5)]] + + datasets[0][0].attrs = {"a": 1} + datasets[0][1].attrs = {"a": 1, "b": 2} + datasets[0][2].attrs = {"a": 1, "c": 3} + datasets[1][0].attrs = {"a": 1, "d": 4} + datasets[1][1].attrs = {"a": 1, "e": 5} + datasets[1][2].attrs = {"a": 1, "f": 6} + + with raises_regex(ValueError, "combine_attrs='identical'"): + result = combine_nested( + datasets, concat_dim=["dim1", "dim2"], combine_attrs="identical" + ) + + for combine_attrs in expected_dict: + result = combine_nested( + datasets, concat_dim=["dim1", "dim2"], combine_attrs=combine_attrs + ) + assert_identical(result, expected_dict[combine_attrs]) + def test_combine_nested_missing_data_new_dim(self): # Your data includes "time" and "station" dimensions, and each year's # data has a different set of stations. 
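
A quick sketch of the user-facing behaviour the new tests above encode, given two toy datasets (an aid to review, not part of the patch; the attribute names "units", "src" and "note" are made up for illustration):

    import xarray as xr

    ds1 = xr.Dataset(
        {"a": ("x", [0])}, coords={"x": [0]}, attrs={"units": "m", "src": "A"}
    )
    ds2 = xr.Dataset(
        {"a": ("x", [1])}, coords={"x": [1]}, attrs={"units": "m", "note": "B"}
    )

    # 'drop' (the default for combine_nested): the result carries no attrs
    combined = xr.combine_nested([ds1, ds2], concat_dim="x", combine_attrs="drop")
    assert combined.attrs == {}

    # 'override': attrs are copied from the first object without comparison
    combined = xr.combine_nested([ds1, ds2], concat_dim="x", combine_attrs="override")
    assert combined.attrs == {"units": "m", "src": "A"}

    # 'no_conflicts': union of all attrs; a shared key must have the same
    # value everywhere, otherwise a MergeError is raised
    combined = xr.combine_nested(
        [ds1, ds2], concat_dim="x", combine_attrs="no_conflicts"
    )
    assert combined.attrs == {"units": "m", "src": "A", "note": "B"}

    # 'identical' would raise here, because ds1.attrs != ds2.attrs
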
@@ -642,6 +685,52 @@ def test_combine_coords_join_exact(self): with raises_regex(ValueError, "indexes along dimension"): combine_nested(objs, concat_dim="x", join="exact") + @pytest.mark.parametrize( + "combine_attrs, expected", + [ + ("drop", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={})), + ( + "no_conflicts", + Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1, "b": 2}), + ), + ("override", Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1})), + ], + ) + def test_combine_coords_combine_attrs(self, combine_attrs, expected): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1, "b": 2}), + ] + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs=combine_attrs + ) + assert_identical(expected, actual) + + if combine_attrs == "no_conflicts": + objs[1].attrs["a"] = 2 + with raises_regex(ValueError, "combine_attrs='no_conflicts'"): + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs=combine_attrs + ) + + def test_combine_coords_combine_attrs_identical(self): + objs = [ + Dataset({"x": [0], "y": [0]}, attrs={"a": 1}), + Dataset({"x": [1], "y": [1]}, attrs={"a": 1}), + ] + expected = Dataset({"x": [0, 1], "y": [0, 1]}, attrs={"a": 1}) + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="identical" + ) + assert_identical(expected, actual) + + objs[1].attrs["b"] = 2 + + with raises_regex(ValueError, "combine_attrs='identical'"): + actual = combine_nested( + objs, concat_dim="x", join="outer", combine_attrs="identical" + ) + def test_infer_order_from_coords(self): data = create_test_data() objs = [data.isel(dim2=slice(4, 9)), data.isel(dim2=slice(4))] diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 1a498496c03..e5038dd4af2 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -256,6 +256,28 @@ def test_concat_join_kwarg(self): ) assert_identical(actual, expected) + def test_concat_combine_attrs_kwarg(self): + ds1 = Dataset({"a": ("x", [0])}, coords={"x": [0]}, attrs={"b": 42}) + ds2 = Dataset({"a": ("x", [0])}, coords={"x": [1]}, attrs={"b": 42, "c": 43}) + + expected = {} + expected["drop"] = Dataset({"a": ("x", [0, 0])}, {"x": [0, 1]}) + expected["no_conflicts"] = Dataset( + {"a": ("x", [0, 0])}, {"x": [0, 1]}, {"b": 42, "c": 43} + ) + expected["override"] = Dataset({"a": ("x", [0, 0])}, {"x": [0, 1]}, {"b": 42}) + + with raises_regex(ValueError, "combine_attrs='identical'"): + actual = concat([ds1, ds2], dim="x", combine_attrs="identical") + with raises_regex(ValueError, "combine_attrs='no_conflicts'"): + ds3 = ds2.copy(deep=True) + ds3.attrs["b"] = 44 + actual = concat([ds1, ds3], dim="x", combine_attrs="no_conflicts") + + for combine_attrs in expected: + actual = concat([ds1, ds2], dim="x", combine_attrs=combine_attrs) + assert_identical(actual, expected[combine_attrs]) + def test_concat_promote_shape(self): # mixed dims within variables objs = [Dataset({}, {"x": 0}), Dataset({"x": [1]})] @@ -469,6 +491,30 @@ def test_concat_join_kwarg(self): actual = concat([ds1, ds2], join=join, dim="x") assert_equal(actual, expected[join].to_array()) + def test_concat_combine_attrs_kwarg(self): + da1 = DataArray([0], coords=[("x", [0])], attrs={"b": 42}) + da2 = DataArray([0], coords=[("x", [1])], attrs={"b": 42, "c": 43}) + + expected = {} + expected["drop"] = DataArray([0, 0], coords=[("x", [0, 1])]) + expected["no_conflicts"] = DataArray( + [0, 0], coords=[("x", [0, 1])], attrs={"b": 42, "c": 43} + ) + expected["override"] 
= DataArray( + [0, 0], coords=[("x", [0, 1])], attrs={"b": 42} + ) + + with raises_regex(ValueError, "combine_attrs='identical'"): + actual = concat([da1, da2], dim="x", combine_attrs="identical") + with raises_regex(ValueError, "combine_attrs='no_conflicts'"): + da3 = da2.copy(deep=True) + da3.attrs["b"] = 44 + actual = concat([da1, da3], dim="x", combine_attrs="no_conflicts") + + for combine_attrs in expected: + actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs) + assert_identical(actual, expected[combine_attrs]) + @pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) @pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index fbd9810f285..4f19dc2a9cf 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3750,9 +3750,16 @@ def test_to_dataset_whole(self): expected = Dataset({"foo": ("x", [1, 2])}) assert_identical(expected, actual) - named = DataArray([1, 2], dims="x", name="foo") + named = DataArray([1, 2], dims="x", name="foo", attrs={"y": "testattr"}) actual = named.to_dataset() - expected = Dataset({"foo": ("x", [1, 2])}) + expected = Dataset({"foo": ("x", [1, 2], {"y": "testattr"})}) + assert_identical(expected, actual) + + # Test promoting attrs + actual = named.to_dataset(promote_attrs=True) + expected = Dataset( + {"foo": ("x", [1, 2], {"y": "testattr"})}, attrs={"y": "testattr"} + ) assert_identical(expected, actual) with pytest.raises(TypeError): diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index 6c8f3f65657..9057575b38c 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -3,6 +3,7 @@ import xarray as xr from xarray.core import dtypes, merge +from xarray.core.merge import MergeError from xarray.testing import assert_identical from . 
import raises_regex @@ -49,6 +50,65 @@ def test_merge_dataarray_unnamed(self): with raises_regex(ValueError, "without providing an explicit name"): xr.merge([data]) + def test_merge_arrays_attrs_default(self): + var1_attrs = {"a": 1, "b": 2} + var2_attrs = {"a": 1, "c": 3} + expected_attrs = {} + + data = create_test_data() + data.var1.attrs = var1_attrs + data.var2.attrs = var2_attrs + actual = xr.merge([data.var1, data.var2]) + expected = data[["var1", "var2"]] + expected.attrs = expected_attrs + assert actual.identical(expected) + + @pytest.mark.parametrize( + "combine_attrs, var1_attrs, var2_attrs, expected_attrs, " "expect_exception", + [ + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 1, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + False, + ), + ("no_conflicts", {"a": 1, "b": 2}, {}, {"a": 1, "b": 2}, False), + ("no_conflicts", {}, {"a": 1, "c": 3}, {"a": 1, "c": 3}, False), + ( + "no_conflicts", + {"a": 1, "b": 2}, + {"a": 4, "c": 3}, + {"a": 1, "b": 2, "c": 3}, + True, + ), + ("drop", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 2}, False), + ("identical", {"a": 1, "b": 2}, {"a": 1, "c": 3}, {"a": 1, "b": 2}, True), + ( + "override", + {"a": 1, "b": 2}, + {"a": 4, "b": 5, "c": 3}, + {"a": 1, "b": 2}, + False, + ), + ], + ) + def test_merge_arrays_attrs( + self, combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception + ): + data = create_test_data() + data.var1.attrs = var1_attrs + data.var2.attrs = var2_attrs + if expect_exception: + with raises_regex(MergeError, "combine_attrs"): + actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) + else: + actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs) + expected = data[["var1", "var2"]] + expected.attrs = expected_attrs + assert actual.identical(expected) + def test_merge_dicts_simple(self): actual = xr.merge([{"foo": 0}, {"bar": "one"}, {"baz": 3.5}]) expected = xr.Dataset({"foo": 0, "bar": "one", "baz": 3.5}) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index af87b94393d..ddca6c57064 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -9,7 +9,7 @@ from xarray.core import duck_array_ops, utils from xarray.core.utils import either_dict_or_kwargs -from . import assert_array_equal, requires_cftime, requires_dask +from . import assert_array_equal, raises_regex, requires_cftime, requires_dask from .test_coding_times import _all_cftime_date_types @@ -124,6 +124,15 @@ def test_ordered_dict_intersection(self): assert {"b": "B"} == utils.ordered_dict_intersection(self.x, self.y) assert {} == utils.ordered_dict_intersection(self.x, self.z) + def test_compat_dict_union(self): + assert {"a": "A", "b": "B", "c": "C"} == utils.compat_dict_union(self.x, self.y) + with raises_regex( + ValueError, + "unsafe to merge dictionaries without " + "overriding values; conflicting key", + ): + utils.compat_dict_union(self.x, self.z) + def test_dict_equiv(self): x = {} x["a"] = 3
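
For reference, the four ``combine_attrs`` policies implemented by ``merge_attrs`` in xarray/core/merge.py reduce to the following self-contained sketch; plain ``==`` stands in for ``xarray.core.utils.equivalent``, so treat this as a simplified illustration rather than the shipped code:

    from typing import Any, Dict, List, Optional

    def combine_attrs_sketch(
        attrs_list: List[Dict[Any, Any]], combine_attrs: str
    ) -> Optional[Dict[Any, Any]]:
        if not attrs_list:
            return None  # nothing to merge
        if combine_attrs == "drop":
            return {}
        if combine_attrs == "override":
            # first object wins; no comparison is performed
            return dict(attrs_list[0])
        if combine_attrs == "no_conflicts":
            # union of all attrs; a key seen more than once must map to
            # the same value each time
            result = dict(attrs_list[0])
            for attrs in attrs_list[1:]:
                for key, value in attrs.items():
                    if key in result and result[key] != value:
                        raise ValueError(
                            "combine_attrs='no_conflicts', but values "
                            "conflict for key %r" % key
                        )
                result.update(attrs)
            return result
        if combine_attrs == "identical":
            # every object must carry exactly the same attrs
            result = dict(attrs_list[0])
            for attrs in attrs_list[1:]:
                if dict(attrs) != result:
                    raise ValueError("combine_attrs='identical', but attrs differ")
            return result
        raise ValueError("Unrecognised value for combine_attrs=%s" % combine_attrs)

    # e.g. combine_attrs_sketch([{"a": 1}, {"a": 1, "b": 2}], "no_conflicts")
    # returns {"a": 1, "b": 2}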