Skip to content

Commit 82a659d

Browse files
authored
Support DataTree in xarray.concat() (#10846)
* Support DataTree in xarray.concat()
* fix docstring
* fix mypy errors
* Add comment to explain preexisting_dim
* Add another unit-test
1 parent 3572f4e commit 82a659d

File tree

3 files changed

+270
-14
lines changed

3 files changed

+270
-14
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ v2025.10.2 (unreleased)
1313
New Features
1414
~~~~~~~~~~~~
1515

16-
- :py:func:`merge` now supports merging :py:class:`DataTree` objects
17-
(:issue:`9790`).
16+
- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree`
17+
objects (:issue:`9790`, :issue:`9778`).
1818
By `Stephan Hoyer <https://github.com/shoyer>`_.
1919

2020
Breaking Changes

xarray/structure/concat.py

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131

3232
if TYPE_CHECKING:
33+
from xarray.core.datatree import DataTree
3334
from xarray.core.types import (
3435
CombineAttrsOptions,
3536
CompatOptions,
@@ -40,6 +41,21 @@
4041
T_DataVars = Union[ConcatOptions, Iterable[Hashable], None]
4142

4243

44+
@overload
45+
def concat(
46+
objs: Iterable[DataTree],
47+
dim: Hashable | T_Variable | T_DataArray | pd.Index | Any,
48+
data_vars: T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT,
49+
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault = _COORDS_DEFAULT,
50+
compat: CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT,
51+
positions: Iterable[Iterable[int]] | None = None,
52+
fill_value: object = dtypes.NA,
53+
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
54+
combine_attrs: CombineAttrsOptions = "override",
55+
create_index_for_new_dim: bool = True,
56+
) -> DataTree: ...
57+
58+
4359
# TODO: replace dim: Any by 1D array_likes
4460
@overload
4561
def concat(
@@ -87,7 +103,7 @@ def concat(
87103
88104
Parameters
89105
----------
90-
objs : sequence of Dataset and DataArray
106+
objs : sequence of DataArray, Dataset or DataTree
91107
xarray objects to concatenate together. Each object is expected to
92108
consist of variables and coordinates with matching shapes except for
93109
along the concatenated dimension.
@@ -117,9 +133,7 @@ def concat(
117133
coords : {"minimal", "different", "all"} or list of Hashable, optional
118134
These coordinate variables will be concatenated together:
119135
* "minimal": Only coordinates in which the dimension already appears
120-
are included. If concatenating over a dimension _not_
121-
present in any of the objects, then all data variables will
122-
be concatenated along that new dimension.
136+
are included.
123137
* "different": Coordinates which are not equal (ignoring attributes)
124138
across all datasets are also concatenated (as well as all for which
125139
dimension already appears). Beware: this option may load the data
@@ -180,7 +194,8 @@ def concat(
180194
If a callable, it must expect a sequence of ``attrs`` dicts and a context object
181195
as its only parameters.
182196
create_index_for_new_dim : bool, default: True
183-
Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
197+
Whether to create a new ``PandasIndex`` object when the objects being
198+
concatenated contain scalar variables named ``dim``.
184199
185200
Returns
186201
-------
@@ -265,6 +280,7 @@ def concat(
265280
# dimension already exists
266281
from xarray.core.dataarray import DataArray
267282
from xarray.core.dataset import Dataset
283+
from xarray.core.datatree import DataTree
268284

269285
try:
270286
first_obj, objs = utils.peek_at(objs)
@@ -278,7 +294,20 @@ def concat(
278294
f"compat={compat!r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'"
279295
)
280296

281-
if isinstance(first_obj, DataArray):
297+
if isinstance(first_obj, DataTree):
298+
return _datatree_concat(
299+
objs,
300+
dim=dim,
301+
data_vars=data_vars,
302+
coords=coords,
303+
compat=compat,
304+
positions=positions,
305+
fill_value=fill_value,
306+
join=join,
307+
combine_attrs=combine_attrs,
308+
create_index_for_new_dim=create_index_for_new_dim,
309+
)
310+
elif isinstance(first_obj, DataArray):
282311
return _dataarray_concat(
283312
objs,
284313
dim=dim,
@@ -342,7 +371,7 @@ def _calc_concat_over(
342371
datasets: list[T_Dataset],
343372
dim: Hashable,
344373
all_dims: set[Hashable],
345-
data_vars: T_DataVars | CombineKwargDefault,
374+
data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
346375
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
347376
compat: CompatOptions | CombineKwargDefault,
348377
) -> tuple[set[Hashable], dict[Hashable, bool], list[int], set[Hashable]]:
@@ -574,7 +603,7 @@ def _parse_datasets(
574603

575604
def _dataset_concat(
576605
datasets: Iterable[T_Dataset],
577-
dim: str | T_Variable | T_DataArray | pd.Index,
606+
dim: Hashable | T_Variable | T_DataArray | pd.Index,
578607
data_vars: T_DataVars | CombineKwargDefault,
579608
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
580609
compat: CompatOptions | CombineKwargDefault,
@@ -583,6 +612,8 @@ def _dataset_concat(
583612
join: JoinOptions | CombineKwargDefault,
584613
combine_attrs: CombineAttrsOptions,
585614
create_index_for_new_dim: bool,
615+
*,
616+
preexisting_dim: bool = False,
586617
) -> T_Dataset:
587618
"""
588619
Concatenate a sequence of datasets along a new or existing dimension
@@ -618,6 +649,11 @@ def _dataset_concat(
618649
all_dims, dim_coords, dims_sizes, coord_names, data_names, vars_order = (
619650
_parse_datasets(datasets)
620651
)
652+
if preexisting_dim:
653+
# When concatenating DataTree objects, a dimension may be pre-existing
654+
# because it exists elsewhere on the trees, even if it does not exist
655+
# on the dataset objects at this node.
656+
all_dims.add(dim_name)
621657
indexed_dim_names = set(dim_coords)
622658

623659
both_data_and_coords = coord_names & data_names
@@ -818,8 +854,8 @@ def get_indexes(name):
818854

819855
def _dataarray_concat(
820856
arrays: Iterable[T_DataArray],
821-
dim: str | T_Variable | T_DataArray | pd.Index,
822-
data_vars: T_DataVars | CombineKwargDefault,
857+
dim: Hashable | T_Variable | T_DataArray | pd.Index,
858+
data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
823859
coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
824860
compat: CompatOptions | CombineKwargDefault,
825861
positions: Iterable[Iterable[int]] | None,
@@ -877,3 +913,56 @@ def _dataarray_concat(
877913
result.attrs = merged_attrs
878914

879915
return result
916+
917+
918+
def _datatree_concat(
    objs: Iterable[DataTree],
    dim: Hashable | Variable | T_DataArray | pd.Index | Any,
    data_vars: T_DataVars | Iterable[Hashable] | CombineKwargDefault,
    coords: ConcatOptions | Iterable[Hashable] | CombineKwargDefault,
    compat: CompatOptions | CombineKwargDefault,
    positions: Iterable[Iterable[int]] | None,
    fill_value: Any,
    join: JoinOptions | CombineKwargDefault,
    combine_attrs: CombineAttrsOptions,
    create_index_for_new_dim: bool,
) -> DataTree:
    """
    Concatenate a sequence of datatrees, node by node, along ``dim``.

    The trees are walked in lockstep with ``group_subtrees``; at every path
    the nodes are converted to datasets and handed to ``_dataset_concat``.
    The per-path results are reassembled with ``DataTree.from_dict``.

    Parameters
    ----------
    objs : iterable of DataTree
        Trees to concatenate. The iterable is materialized into a list.
    dim : Hashable, Variable, DataArray or pandas.Index
        Concatenation dimension. Its name is resolved with
        ``_calc_concat_dim_index``; the original object is what gets
        forwarded to ``_dataset_concat``.
    data_vars, coords, compat, positions, fill_value, join, combine_attrs, \
create_index_for_new_dim
        Forwarded unchanged to ``_dataset_concat`` for every group of nodes.
        ``compat="identical"`` additionally requires all tree names to match.

    Returns
    -------
    DataTree
        A new tree named after the first input tree.

    Raises
    ------
    TypeError
        If any element of ``objs`` is not a ``DataTree``.
    ValueError
        If ``compat == "identical"`` and the tree names differ, or if
        ``group_subtrees`` reports a ``TreeIsomorphismError``.
    """
    from xarray.core.datatree import DataTree
    from xarray.core.treenode import TreeIsomorphismError, group_subtrees

    dim_name, _ = _calc_concat_dim_index(dim)

    trees = list(objs)
    for tree in trees:
        if not isinstance(tree, DataTree):
            raise TypeError("All objects to concatenate must be DataTree objects")

    if compat == "identical":
        for other in trees[1:]:
            if other.name != trees[0].name:
                raise ValueError("DataTree names not identical")

    # Only the first tree is inspected: the dimension counts as pre-existing
    # as soon as any of its nodes carries it, even if the node currently being
    # concatenated does not.
    dim_in_tree = False
    for node in trees[0].subtree:
        if dim_name in node.dims:
            dim_in_tree = True
            break

    concatenated = {}
    try:
        # The loop stays inside the ``try`` so that a TreeIsomorphismError
        # raised while iterating group_subtrees is caught below.
        for path, nodes in group_subtrees(*trees):
            concatenated[path] = _dataset_concat(
                [node.to_dataset() for node in nodes],
                dim=dim,
                data_vars=data_vars,
                coords=coords,
                compat=compat,
                positions=positions,
                fill_value=fill_value,
                join=join,
                combine_attrs=combine_attrs,
                create_index_for_new_dim=create_index_for_new_dim,
                preexisting_dim=dim_in_tree,
            )
    except TreeIsomorphismError as err:
        raise ValueError("All trees must be isomorphic to be concatenated") from err

    return DataTree.from_dict(concatenated, name=trees[0].name)

0 commit comments

Comments
 (0)