Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly keep track of indexes with merging #3234

Merged
merged 9 commits into from
Oct 4, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Typing fixes
shoyer committed Aug 20, 2019
commit 7d8312515dc80a1f0ba02a4e9a23004f6bd13a2b
38 changes: 21 additions & 17 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
@@ -31,13 +31,14 @@
from .dataarray import DataArray
from .dataset import Dataset

DatasetLikeValue = Union[
DataArray, Variable, Tuple[Hashable, Any], Tuple[Sequence[Hashable], Any]
DimsLike = Union[Hashable, Sequence[Hashable]]
VariableTuple = Union[
Tuple[DimsLike, Any],
Tuple[DimsLike, Any, Mapping],
Tuple[DimsLike, Any, Mapping, Mapping],
]
DatasetLikeValue = Union[DataArray, Variable, VariableTuple]
DatasetLike = Union[Dataset, Mapping[Hashable, DatasetLikeValue]]
# Any object type that can be used on the rhs of Dataset.update,
# Dataset.merge, etc.
MutableDatasetLike = Union[Dataset, MutableMapping[Hashable, DatasetLikeValue]]


PANDAS_TYPES = (pd.Series, pd.DataFrame, pdcompat.Panel)
@@ -215,7 +216,7 @@ def merge_collected(


def collect_variables_and_indexes(
list_of_mappings: "List[Union[Dataset, OrderedDict]]",
list_of_mappings: "List[DatasetLike]",
) -> "OrderedDict[Hashable, List[MergeElement]]":
"""Collect variables and indexes from list of mappings of xarray objects.

@@ -334,12 +335,12 @@ def determine_coords(
coord_names = set() # type: set
noncoord_names = set() # type: set

for variables in list_of_mappings:
if isinstance(variables, Dataset):
coord_names.update(variables.coords)
noncoord_names.update(variables.data_vars)
for mapping in list_of_mappings:
if isinstance(mapping, Dataset):
coord_names.update(mapping.coords)
noncoord_names.update(mapping.data_vars)
else:
for name, var in variables.items():
for name, var in mapping.items():
if isinstance(var, DataArray):
coords = set(var._coords) # use private API for speed
# explicitly overwritten variables should take precedence
@@ -382,7 +383,9 @@ def coerce_pandas_values(objects: Iterable["DatasetLike"]) -> List["DatasetLike"
return out


def _get_priority_vars_and_indexes(objects, priority_arg, compat="equals"):
def _get_priority_vars_and_indexes(
objects: List["DatasetLike"], priority_arg: Optional[int], compat: str = "equals"
) -> "OrderedDict[Hashable, MergeElement]":
"""Extract the priority variable from a list of mappings.

We need this method because in some cases the priority argument itself
@@ -400,15 +403,14 @@ def _get_priority_vars_and_indexes(objects, priority_arg, compat="equals"):

Returns
-------
None, if priority_arg is None, or an OrderedDict with Variable objects as
values indicating priority variables.
An OrderedDict of variables and associated indexes (if any) to prioritize.
""" # noqa
if priority_arg is None:
return OrderedDict()

collected = collect_variables_and_indexes([objects[priority_arg]])
variables, indexes = merge_collected(collected, compat=compat)
grouped = OrderedDict()
grouped = OrderedDict() # type: OrderedDict[Hashable, MergeElement]
for name, variable in variables.items():
grouped[name] = (variable, indexes.get(name))
return grouped
@@ -669,8 +671,10 @@ def dataset_merge_method(
objs = [dataset, other]
priority_arg = 1
else:
other_overwrite = OrderedDict() # type: MutableDatasetLike
other_no_overwrite = OrderedDict() # type: MutableDatasetLike
other_overwrite = OrderedDict() # type: OrderedDict[Hashable, DatasetLikeValue]
other_no_overwrite = (
OrderedDict()
) # type: OrderedDict[Hashable, DatasetLikeValue]
for k, v in other.items():
if k in overwrite_vars:
other_overwrite[k] = v