3030)
3131
3232if TYPE_CHECKING :
33+ from xarray .core .datatree import DataTree
3334 from xarray .core .types import (
3435 CombineAttrsOptions ,
3536 CompatOptions ,
4041 T_DataVars = Union [ConcatOptions , Iterable [Hashable ], None ]
4142
4243
44+ @overload
45+ def concat (
46+ objs : Iterable [DataTree ],
47+ dim : Hashable | T_Variable | T_DataArray | pd .Index | Any ,
48+ data_vars : T_DataVars | CombineKwargDefault = _DATA_VARS_DEFAULT ,
49+ coords : ConcatOptions | Iterable [Hashable ] | CombineKwargDefault = _COORDS_DEFAULT ,
50+ compat : CompatOptions | CombineKwargDefault = _COMPAT_CONCAT_DEFAULT ,
51+ positions : Iterable [Iterable [int ]] | None = None ,
52+ fill_value : object = dtypes .NA ,
53+ join : JoinOptions | CombineKwargDefault = _JOIN_DEFAULT ,
54+ combine_attrs : CombineAttrsOptions = "override" ,
55+ create_index_for_new_dim : bool = True ,
56+ ) -> DataTree : ...
57+
58+
4359# TODO: replace dim: Any by 1D array_likes
4460@overload
4561def concat (
@@ -87,7 +103,7 @@ def concat(
87103
88104 Parameters
89105 ----------
90- objs : sequence of Dataset and DataArray
106+ objs : sequence of DataArray, Dataset or DataTree
91107 xarray objects to concatenate together. Each object is expected to
92108 consist of variables and coordinates with matching shapes except for
93109 along the concatenated dimension.
@@ -117,9 +133,7 @@ def concat(
117133 coords : {"minimal", "different", "all"} or list of Hashable, optional
118134 These coordinate variables will be concatenated together:
119135 * "minimal": Only coordinates in which the dimension already appears
120- are included. If concatenating over a dimension _not_
121- present in any of the objects, then all data variables will
122- be concatenated along that new dimension.
136+ are included.
123137 * "different": Coordinates which are not equal (ignoring attributes)
124138 across all datasets are also concatenated (as well as all for which
125139 dimension already appears). Beware: this option may load the data
@@ -180,7 +194,8 @@ def concat(
180194 If a callable, it must expect a sequence of ``attrs`` dicts and a context object
181195 as its only parameters.
182196 create_index_for_new_dim : bool, default: True
183- Whether to create a new ``PandasIndex`` object when the objects being concatenated contain scalar variables named ``dim``.
197+ Whether to create a new ``PandasIndex`` object when the objects being
198+ concatenated contain scalar variables named ``dim``.
184199
185200 Returns
186201 -------
@@ -265,6 +280,7 @@ def concat(
265280 # dimension already exists
266281 from xarray .core .dataarray import DataArray
267282 from xarray .core .dataset import Dataset
283+ from xarray .core .datatree import DataTree
268284
269285 try :
270286 first_obj , objs = utils .peek_at (objs )
@@ -278,7 +294,20 @@ def concat(
278294 f"compat={ compat !r} invalid: must be 'broadcast_equals', 'equals', 'identical', 'no_conflicts' or 'override'"
279295 )
280296
281- if isinstance (first_obj , DataArray ):
297+ if isinstance (first_obj , DataTree ):
298+ return _datatree_concat (
299+ objs ,
300+ dim = dim ,
301+ data_vars = data_vars ,
302+ coords = coords ,
303+ compat = compat ,
304+ positions = positions ,
305+ fill_value = fill_value ,
306+ join = join ,
307+ combine_attrs = combine_attrs ,
308+ create_index_for_new_dim = create_index_for_new_dim ,
309+ )
310+ elif isinstance (first_obj , DataArray ):
282311 return _dataarray_concat (
283312 objs ,
284313 dim = dim ,
@@ -342,7 +371,7 @@ def _calc_concat_over(
342371 datasets : list [T_Dataset ],
343372 dim : Hashable ,
344373 all_dims : set [Hashable ],
345- data_vars : T_DataVars | CombineKwargDefault ,
374+ data_vars : T_DataVars | Iterable [ Hashable ] | CombineKwargDefault ,
346375 coords : ConcatOptions | Iterable [Hashable ] | CombineKwargDefault ,
347376 compat : CompatOptions | CombineKwargDefault ,
348377) -> tuple [set [Hashable ], dict [Hashable , bool ], list [int ], set [Hashable ]]:
@@ -574,7 +603,7 @@ def _parse_datasets(
574603
575604def _dataset_concat (
576605 datasets : Iterable [T_Dataset ],
577- dim : str | T_Variable | T_DataArray | pd .Index ,
606+ dim : Hashable | T_Variable | T_DataArray | pd .Index ,
578607 data_vars : T_DataVars | CombineKwargDefault ,
579608 coords : ConcatOptions | Iterable [Hashable ] | CombineKwargDefault ,
580609 compat : CompatOptions | CombineKwargDefault ,
@@ -583,6 +612,8 @@ def _dataset_concat(
583612 join : JoinOptions | CombineKwargDefault ,
584613 combine_attrs : CombineAttrsOptions ,
585614 create_index_for_new_dim : bool ,
615+ * ,
616+ preexisting_dim : bool = False ,
586617) -> T_Dataset :
587618 """
588619 Concatenate a sequence of datasets along a new or existing dimension
@@ -618,6 +649,11 @@ def _dataset_concat(
618649 all_dims , dim_coords , dims_sizes , coord_names , data_names , vars_order = (
619650 _parse_datasets (datasets )
620651 )
652+ if preexisting_dim :
653+ # When concatenating DataTree objects, a dimension may be pre-existing
654+ # because it exists elsewhere on the trees, even if it does not exist
655+ # on the dataset objects at this node.
656+ all_dims .add (dim_name )
621657 indexed_dim_names = set (dim_coords )
622658
623659 both_data_and_coords = coord_names & data_names
@@ -818,8 +854,8 @@ def get_indexes(name):
818854
819855def _dataarray_concat (
820856 arrays : Iterable [T_DataArray ],
821- dim : str | T_Variable | T_DataArray | pd .Index ,
822- data_vars : T_DataVars | CombineKwargDefault ,
857+ dim : Hashable | T_Variable | T_DataArray | pd .Index ,
858+ data_vars : T_DataVars | Iterable [ Hashable ] | CombineKwargDefault ,
823859 coords : ConcatOptions | Iterable [Hashable ] | CombineKwargDefault ,
824860 compat : CompatOptions | CombineKwargDefault ,
825861 positions : Iterable [Iterable [int ]] | None ,
@@ -877,3 +913,56 @@ def _dataarray_concat(
877913 result .attrs = merged_attrs
878914
879915 return result
916+
917+
918+ def _datatree_concat (
919+ objs : Iterable [DataTree ],
920+ dim : Hashable | Variable | T_DataArray | pd .Index | Any ,
921+ data_vars : T_DataVars | Iterable [Hashable ] | CombineKwargDefault ,
922+ coords : ConcatOptions | Iterable [Hashable ] | CombineKwargDefault ,
923+ compat : CompatOptions | CombineKwargDefault ,
924+ positions : Iterable [Iterable [int ]] | None ,
925+ fill_value : Any ,
926+ join : JoinOptions | CombineKwargDefault ,
927+ combine_attrs : CombineAttrsOptions ,
928+ create_index_for_new_dim : bool ,
929+ ) -> DataTree :
930+ """
931+ Concatenate a sequence of datatrees along a new or existing dimension
932+ """
933+ from xarray .core .datatree import DataTree
934+ from xarray .core .treenode import TreeIsomorphismError , group_subtrees
935+
936+ dim_name , _ = _calc_concat_dim_index (dim )
937+
938+ objs = list (objs )
939+ if not all (isinstance (obj , DataTree ) for obj in objs ):
940+ raise TypeError ("All objects to concatenate must be DataTree objects" )
941+
942+ if compat == "identical" :
943+ if any (obj .name != objs [0 ].name for obj in objs [1 :]):
944+ raise ValueError ("DataTree names not identical" )
945+
946+ dim_in_tree = any (dim_name in node .dims for node in objs [0 ].subtree )
947+
948+ results = {}
949+ try :
950+ for path , nodes in group_subtrees (* objs ):
951+ datasets_to_concat = [node .to_dataset () for node in nodes ]
952+ results [path ] = _dataset_concat (
953+ datasets_to_concat ,
954+ dim = dim ,
955+ data_vars = data_vars ,
956+ coords = coords ,
957+ compat = compat ,
958+ positions = positions ,
959+ fill_value = fill_value ,
960+ join = join ,
961+ combine_attrs = combine_attrs ,
962+ create_index_for_new_dim = create_index_for_new_dim ,
963+ preexisting_dim = dim_in_tree ,
964+ )
965+ except TreeIsomorphismError as e :
966+ raise ValueError ("All trees must be isomorphic to be concatenated" ) from e
967+
968+ return DataTree .from_dict (results , name = objs [0 ].name )
0 commit comments