From cb5ce9ac6f74e14812afe7125a9ffaddf2f53afc Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 18:14:36 +0200 Subject: [PATCH 01/45] type filename and chunks --- xarray/backends/api.py | 61 ++++++++++++++++++++++++++------------- xarray/backends/common.py | 8 +++-- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 95f8dbc6eaf..b13c9084095 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1,18 +1,24 @@ +from __future__ import annotations + import os from glob import glob from io import BytesIO from numbers import Number from typing import ( TYPE_CHECKING, + Any, Callable, Dict, + Final, Hashable, Iterable, Mapping, MutableMapping, Optional, Tuple, + Type, Union, + Literal, ) import numpy as np @@ -36,6 +42,7 @@ from dask.delayed import Delayed except ImportError: Delayed = None # type: ignore + from .common import BackendEntrypoint DATAARRAY_NAME = "__xarray_dataarray_name__" @@ -52,8 +59,24 @@ "zarr": backends.ZarrStore.open_group, } - -def _get_default_engine_remote_uri(): +T_ENGINE = Union[ + Literal[ + "netcdf4", + "scipy", + "pydap", + "h5netcdf", + "pynio", + "pseudonetcdf", + "cfgrib", + "zarr", + ], + None, + Type[BackendEntrypoint], +] +T_CHUNKS = Union[int, dict[Any, Any], None, Literal["auto"]] + +def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: + engine: Literal["netcdf4", "pydap"] try: import netCDF4 # noqa: F401 @@ -71,17 +94,18 @@ def _get_default_engine_remote_uri(): return engine -def _get_default_engine_gz(): +def _get_default_engine_gz() -> Literal["scipy"]: try: import scipy # noqa: F401 - engine = "scipy" + engine: Final = "scipy" except ImportError: # pragma: no cover raise ValueError("scipy is required for accessing .gz files") return engine -def _get_default_engine_netcdf(): +def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]: + engine: Literal["netcdf4", "scipy"] try: import netCDF4 # noqa: F401 @@ -99,7 +123,9 @@ def _get_default_engine_netcdf(): return engine -def _get_default_engine(path: str, allow_remote: bool = False): +def _get_default_engine( + path: str, allow_remote: bool = False +) -> Literal["netcdf4", "scipy", "pydap"]: if allow_remote and is_remote_uri(path): return _get_default_engine_remote_uri() elif path.endswith(".gz"): @@ -108,10 +134,10 @@ def _get_default_engine(path: str, allow_remote: bool = False): return _get_default_engine_netcdf() -def _validate_dataset_names(dataset): +def _validate_dataset_names(dataset: Dataset) -> None: """DataArray.name and Dataset keys must be a string or None""" - def check_name(name): + def check_name(name: Hashable): if isinstance(name, str): if not name: raise ValueError( @@ -216,7 +242,7 @@ def _finalize_store(write, store): store.close() -def load_dataset(filename_or_obj, **kwargs): +def load_dataset(filename_or_obj, **kwargs) -> Dataset: """Open, load into memory, and close a Dataset from a file or file-like object. @@ -337,10 +363,10 @@ def _dataset_from_backend_dataset( def open_dataset( - filename_or_obj, - *args, - engine=None, - chunks=None, + filename_or_obj: str | os.PathLike, + *, + engine: T_ENGINE = None, + chunks: T_CHUNKS = None, cache=None, decode_cf=None, mask_and_scale=None, @@ -353,7 +379,7 @@ def open_dataset( inline_array=False, backend_kwargs=None, **kwargs, -): +) -> Dataset: """Open and decode a dataset from a file or file-like object. 
Parameters @@ -370,7 +396,7 @@ def open_dataset( is chosen based on available dependencies, with a preference for "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) can also be used. - chunks : int or dict, optional + chunks : int, dict or 'auto', optional If chunks is provided, it is used to load the new dataset into dask arrays. ``chunks=-1`` loads the dataset with dask using a single chunk for all arrays. ``chunks={}`` loads the dataset with dask using @@ -474,11 +500,6 @@ def open_dataset( -------- open_mfdataset """ - if len(args) > 0: - raise TypeError( - "open_dataset() takes only 1 positional argument starting from version 0.18.0, " - "all other options must be passed as keyword arguments" - ) if cache is None: cache = chunks is None diff --git a/xarray/backends/common.py b/xarray/backends/common.py index ad92a6c5869..42b5f8aedac 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import time @@ -374,8 +376,8 @@ class BackendEntrypoint: def open_dataset( self, - filename_or_obj: str, - drop_variables: Tuple[str] = None, + filename_or_obj: str | os.PathLike, + drop_variables: Tuple[str] | None = None, **kwargs: Any, ): """ @@ -384,7 +386,7 @@ def open_dataset( raise NotImplementedError - def guess_can_open(self, filename_or_obj): + def guess_can_open(self, filename_or_obj: str | os.PathLike): """ Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`. """ From e0a3c4669b1223ac5b3721953521371f2ba060f5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 19:21:57 +0200 Subject: [PATCH 02/45] type open_dataset, open_dataarray, open_mfdataset --- xarray/backends/api.py | 139 ++++++++++++++++++------------- xarray/core/combine.py | 10 ++- xarray/core/computation.py | 21 ++++- xarray/core/dataset.py | 18 ++-- xarray/core/merge.py | 22 +++-- xarray/tests/test_distributed.py | 4 +- 6 files changed, 138 insertions(+), 76 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b13c9084095..76660825bb2 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -7,6 +7,7 @@ from typing import ( TYPE_CHECKING, Any, + cast, Callable, Dict, Final, @@ -15,9 +16,11 @@ Mapping, MutableMapping, Optional, + Sequence, Tuple, Type, Union, + List, Literal, ) @@ -32,6 +35,7 @@ ) from ..core.dataarray import DataArray from ..core.dataset import Dataset, _get_chunk, _maybe_chunk +from ..core.indexes import Index from ..core.utils import is_remote_uri from . import plugins from .common import AbstractDataStore, ArrayWriter, _normalize_path @@ -367,17 +371,17 @@ def open_dataset( *, engine: T_ENGINE = None, chunks: T_CHUNKS = None, - cache=None, - decode_cf=None, - mask_and_scale=None, - decode_times=None, - decode_timedelta=None, - use_cftime=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - inline_array=False, - backend_kwargs=None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | None = None, + decode_times: bool | None = None, + decode_timedelta: bool | None = None, + use_cftime: bool | None = None, + concat_characters: bool | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + backend_kwargs: Dict[str, Any] | None = None, **kwargs, ) -> Dataset: """Open and decode a dataset from a file or file-like object. 
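A minimal usage sketch of the signature typed above; the file name is hypothetical, ``chunks="auto"`` assumes dask is available, and the keyword values are drawn from the Literal alternatives in the annotations, so a checker such as mypy can flag a misspelled value like ``engine="netcdf5"``:

    import xarray as xr

    ds = xr.open_dataset(
        "observations.nc",      # hypothetical file
        engine="netcdf4",       # one of the allowed engine literals
        chunks="auto",          # int, dict, "auto" or None
        decode_coords="all",    # "coordinates", "all" or a bool
    )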
@@ -396,7 +400,7 @@ def open_dataset( is chosen based on available dependencies, with a preference for "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``) can also be used. - chunks : int, dict or 'auto', optional + chunks : int, dict, 'auto' or None, optional If chunks is provided, it is used to load the new dataset into dask arrays. ``chunks=-1`` loads the dataset with dask using a single chunk for all arrays. ``chunks={}`` loads the dataset with dask using @@ -457,11 +461,11 @@ def open_dataset( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. - drop_variables: str or iterable, optional + drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. - inline_array: bool, optional + inline_array: bool, default: False How to include the array in the dask task graph. By default(``inline_array=False``) the array is included in a task by itself, and each chunk refers to that task by its key. With @@ -546,23 +550,23 @@ def open_dataset( def open_dataarray( - filename_or_obj, - *args, - engine=None, - chunks=None, - cache=None, - decode_cf=None, - mask_and_scale=None, - decode_times=None, - decode_timedelta=None, - use_cftime=None, - concat_characters=None, - decode_coords=None, - drop_variables=None, - inline_array=False, - backend_kwargs=None, + filename_or_obj: str | os.PathLike, + *, + engine: T_ENGINE = None, + chunks: T_CHUNKS = None, + cache: bool | None = None, + decode_cf: bool | None = None, + mask_and_scale: bool | None = None, + decode_times: bool | None = None, + decode_timedelta: bool | None = None, + use_cftime: bool | None = None, + concat_characters: bool | None = None, + decode_coords: Literal["coordinates", "all"] | bool | None = None, + drop_variables: str | Iterable[str] | None = None, + inline_array: bool = False, + backend_kwargs: Dict[str, Any] | None = None, **kwargs, -): +) -> DataArray: """Open an DataArray from a file or file-like object containing a single data variable. @@ -582,7 +586,7 @@ def open_dataarray( Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4". - chunks : int or dict, optional + chunks : int, dict, 'auto' or None, optional If chunks is provided, it is used to load the new dataset into dask arrays. ``chunks=-1`` loads the dataset with dask using a single chunk for all arrays. `chunks={}`` loads the dataset with dask using @@ -643,11 +647,11 @@ def open_dataarray( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. - drop_variables: str or iterable, optional + drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. - inline_array: bool, optional + inline_array: bool, default: False How to include the array in the dask task graph. By default(``inline_array=False``) the array is included in a task by itself, and each chunk refers to that task by its key. 
With @@ -683,11 +687,6 @@ def open_dataarray( -------- open_dataset """ - if len(args) > 0: - raise TypeError( - "open_dataarray() takes only 1 positional argument starting from version 0.18.0, " - "all other options must be passed as keyword arguments" - ) dataset = open_dataset( filename_or_obj, @@ -731,21 +730,25 @@ def open_dataarray( def open_mfdataset( - paths, - chunks=None, - concat_dim=None, - compat="no_conflicts", - preprocess=None, - engine=None, - data_vars="all", + paths: str | Iterable[str | os.PathLike], + chunks: T_CHUNKS = None, + concat_dim: str | DataArray | Index | Sequence[str] | Sequence[DataArray] | Sequence[Index] | None = None, + compat: Literal[ + "identical", "equals", "broadcast_equals", "no_conflicts", "override" + ] = "no_conflicts", + preprocess: Callable[[Dataset], Dataset] | None = None, + engine: T_ENGINE = None, + data_vars: Literal["all", "minimal", "different"] | List[str] = "all", coords="different", - combine="by_coords", - parallel=False, - join="outer", - attrs_file=None, - combine_attrs="override", + combine: Literal["by_coords", "nested"] = "by_coords", + parallel: bool = False, + join: Literal["outer", "inner", "left", "right", "exact", "override"] = "outer", + attrs_file: str | os.PathLike | None = None, + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", **kwargs, -): +) -> Dataset: """Open multiple files as a single dataset. If combine='by_coords' then the function ``combine_by_coords`` is used to combine @@ -759,19 +762,19 @@ def open_mfdataset( Parameters ---------- - paths : str or sequence + paths : str or Iterable of paths Either a string glob in the form ``"path/to/my/files/*.nc"`` or an explicit list of files to open. Paths can be given as strings or as pathlib Paths. If concatenation along more than one dimension is desired, then ``paths`` must be a nested list-of-lists (see ``combine_nested`` for details). (A string glob will be expanded to a 1-dimensional list.) - chunks : int or dict, optional + chunks : int, dict, 'auto' or None, optional Dictionary with keys given by dimension names and values given by chunk sizes. In general, these should divide the dimensions of each dataset. If int, chunk each dimension by ``chunks``. By default, chunks will be chosen to load entire input files into memory at once. This has a major impact on performance: please see the full documentation for more details [2]_. - concat_dim : str, or list of str, DataArray, Index or None, optional + concat_dim : str, DataArray, Index or a Sequence of these or None, optional Dimensions to concatenate files along. You only need to provide this argument if ``combine='nested'``, and if any of the dimensions along which you want to concatenate is not a dimension in the original datasets, e.g., if you want to @@ -784,7 +787,7 @@ def open_mfdataset( Whether ``xarray.combine_by_coords`` or ``xarray.combine_nested`` is used to combine all the data. Default is to use ``xarray.combine_by_coords``. compat : {"identical", "equals", "broadcast_equals", \ - "no_conflicts", "override"}, optional + "no_conflicts", "override"}, default: "no_conflicts" String indicating how to compare variables of the same name for potential conflicts when merging: @@ -807,7 +810,7 @@ def open_mfdataset( Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for "netcdf4". 
- data_vars : {"minimal", "different", "all"} or list of str, optional + data_vars : {"minimal", "different", "all"} or list of str, default: "all" These data variables will be concatenated together: * "minimal": Only data variables in which the dimension already appears are included. @@ -832,10 +835,10 @@ def open_mfdataset( those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, in addition the "minimal" coordinates. - parallel : bool, optional + parallel : bool, default: False If True, the open and preprocess steps of this function will be performed in parallel using ``dask.delayed``. Default is False. - join : {"outer", "inner", "left", "right", "exact, "override"}, optional + join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer" String indicating how to combine differing indexes (excluding concat_dim) in objects @@ -852,6 +855,22 @@ def open_mfdataset( Path of the file used to read global attributes from. By default global attributes are read from the first file provided, with wildcard matches sorted by filename. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. **kwargs : optional Additional arguments passed on to :py:func:`xarray.open_dataset`. @@ -915,7 +934,7 @@ def open_mfdataset( if combine == "nested": if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: - concat_dim = [concat_dim] + concat_dim = [concat_dim] # type: ignore[assignment] # This creates a flat list which is easier to iterate over, whilst # encoding the originally-supplied structure as "ids". 
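A short sketch of an ``open_mfdataset`` call using the newly annotated parameters; the file names are hypothetical and ``parallel=True`` assumes dask is installed:

    import xarray as xr

    ds = xr.open_mfdataset(
        ["run_0001.nc", "run_0002.nc"],  # hypothetical input files
        combine="nested",                # "by_coords" or "nested"
        concat_dim="time",               # str, DataArray, Index, a sequence of these, or None
        parallel=True,                   # open/preprocess via dask.delayed
    )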
@@ -1001,7 +1020,7 @@ def multi_file_closer(): # read global attributes from the attrs_file or from the first dataset if attrs_file is not None: if isinstance(attrs_file, os.PathLike): - attrs_file = os.fspath(attrs_file) + attrs_file = cast(str, os.fspath(attrs_file)) combined.attrs = datasets[paths.index(attrs_file)].attrs return combined diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 78f016fdccd..e766a1ba58a 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import itertools import warnings from collections import Counter -from typing import Iterable, Sequence, Union +from typing import Any, Callable, Iterable, Sequence, Union, Literal, List import pandas as pd @@ -661,11 +663,13 @@ def _combine_single_variable_hypercube( def combine_by_coords( data_objects: Sequence[Union[Dataset, DataArray]] = [], compat: str = "no_conflicts", - data_vars: str = "all", + data_vars: Literal["all", "minimal", "different"] | List[str] = "all", coords: str = "different", fill_value: object = dtypes.NA, join: str = "outer", - combine_attrs: str = "no_conflicts", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "no_conflicts", datasets: Sequence[Dataset] = None, ) -> Union[Dataset, DataArray]: """ diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 81b5e3fd915..a52e6fa7756 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -15,6 +15,7 @@ Callable, Hashable, Iterable, + Literal, Mapping, Sequence, overload, @@ -212,7 +213,9 @@ def build_output_coords_and_indexes( args: list, signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), - combine_attrs: str = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", ) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: """Build output coordinates and indexes for an operation. @@ -226,6 +229,22 @@ def build_output_coords_and_indexes( exclude_dims : set, optional Dimensions excluded from the operation. Coordinates along these dimensions are dropped. + combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "drop" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. Returns ------- diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b73cd797a8f..d094f47b9b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4439,7 +4439,9 @@ def merge( compat: str = "no_conflicts", join: str = "outer", fill_value: Any = dtypes.NA, - combine_attrs: str = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", ) -> Dataset: """Merge the arrays of two datasets into a single dataset. 
@@ -4480,17 +4482,21 @@ def merge( Value to use for newly missing values. If a dict-like, maps variable names (including coordinates) to fill values. combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ - "override"}, default: "override" - String indicating how to combine attrs of the objects being merged: + "override"} or callable, default: "drop" + A callable or a string indicating how to combine attrs of the objects being + merged: - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have - the same name must also have the same value. + the same name must also have the same value. - "drop_conflicts": attrs from all objects are combined, any that have - the same name but different values are dropped. + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to - the result. + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. Returns ------- diff --git a/xarray/core/merge.py b/xarray/core/merge.py index b428d4ae958..53c3b1f5f71 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -5,6 +5,7 @@ TYPE_CHECKING, AbstractSet, Any, + Callable, Hashable, Iterable, Mapping, @@ -12,6 +13,7 @@ Optional, Sequence, Tuple, + Literal, Union, ) @@ -208,7 +210,9 @@ def merge_collected( grouped: dict[Hashable, list[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", - combine_attrs: str | None = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", equals: dict[Hashable, bool] = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -376,7 +380,9 @@ def merge_coordinates_without_align( objects: list[Coordinates], prioritized: Mapping[Any, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), - combine_attrs: str = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. @@ -667,7 +673,9 @@ def merge_core( objects: Iterable[CoercibleMapping], compat: str = "broadcast_equals", join: str = "outer", - combine_attrs: str | None = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", priority_arg: int | None = None, explicit_coords: Sequence | None = None, indexes: Mapping[Any, Any] | None = None, @@ -757,7 +765,9 @@ def merge( compat: str = "no_conflicts", join: str = "outer", fill_value: object = dtypes.NA, - combine_attrs: str = "override", + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any] = "override", ) -> Dataset: """Merge any number of xarray objects into a single Dataset as variables. 
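A sketch of the callable form of ``combine_attrs`` described in the docstring above; ``keep_common_attrs`` is a hypothetical reducer, but its signature follows the documented contract of a sequence of ``attrs`` dicts plus a context object:

    import xarray as xr

    def keep_common_attrs(attrs_dicts, context):
        # retain only the attributes that agree across all merged objects
        common = dict(attrs_dicts[0])
        for attrs in attrs_dicts[1:]:
            common = {k: v for k, v in common.items() if attrs.get(k) == v}
        return common

    # ds1 and ds2 are assumed to be existing Datasets
    merged = xr.merge([ds1, ds2], combine_attrs=keep_common_attrs)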
@@ -1005,7 +1015,9 @@ def dataset_merge_method( compat: str, join: str, fill_value: Any, - combine_attrs: str, + combine_attrs: Literal[ + "drop", "identical", "no_conflicts", "drop_conflicts", "override" + ] | Callable[..., Any], ) -> _MergeResult: """Guts of the Dataset.merge method.""" # we are locked into supporting overwrite_vars for the Dataset.merge diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 773733b7b89..8ac87bbc807 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -2,6 +2,8 @@ import pickle import numpy as np +from typing import Any, Dict + import pytest from packaging.version import Version @@ -156,7 +158,7 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> if consolidated: pytest.importorskip("zarr", minversion="2.2.1.dev2") write_kwargs = {"consolidated": True} - read_kwargs = {"backend_kwargs": {"consolidated": True}} + read_kwargs: Dict[str, Any] = {"backend_kwargs": {"consolidated": True}} else: write_kwargs = read_kwargs = {} # type: ignore chunks = {"dim1": 4, "dim2": 3, "dim3": 5} From 4a02e8768833a605867f72c7a2ddf2645730ef79 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 20:03:31 +0200 Subject: [PATCH 03/45] type to_netcdf --- xarray/backends/api.py | 116 ++++++++++++++++++++++++++++++++--------- xarray/core/dataset.py | 66 +++++++++++++++++++---- 2 files changed, 148 insertions(+), 34 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 76660825bb2..287443fab5c 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -22,10 +22,13 @@ Union, List, Literal, + overload, ) import numpy as np +from xarray.backends.zarr import ZarrStore + from .. import backends, conventions from ..core import indexing from ..core.combine import ( @@ -63,21 +66,16 @@ "zarr": backends.ZarrStore.open_group, } +T_NETCDFENGINE = Literal["netcdf4", "scipy", "h5netcdf"] T_ENGINE = Union[ - Literal[ - "netcdf4", - "scipy", - "pydap", - "h5netcdf", - "pynio", - "pseudonetcdf", - "cfgrib", - "zarr", - ], - None, + T_NETCDFENGINE, + Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"], Type[BackendEntrypoint], ] -T_CHUNKS = Union[int, dict[Any, Any], None, Literal["auto"]] +T_CHUNKS = Union[int, dict[Any, Any], Literal["auto"], None] +T_NETCDFTYPES = Literal[ + "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" +] def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: engine: Literal["netcdf4", "pydap"] @@ -129,9 +127,9 @@ def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]: def _get_default_engine( path: str, allow_remote: bool = False -) -> Literal["netcdf4", "scipy", "pydap"]: +) -> T_NETCDFENGINE: if allow_remote and is_remote_uri(path): - return _get_default_engine_remote_uri() + return _get_default_engine_remote_uri() # type: ignore[return-value] elif path.endswith(".gz"): return _get_default_engine_gz() else: @@ -1033,19 +1031,87 @@ def multi_file_closer(): } +@overload def to_netcdf( dataset: Dataset, - path_or_file=None, - mode: str = "w", - format: str = None, - group: str = None, - engine: str = None, - encoding: Mapping = None, - unlimited_dims: Iterable[Hashable] = None, + path_or_file: str | os.PathLike | None, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: bool, + multifile: Literal[True], + 
invalid_netcdf: bool, +) -> Tuple[ArrayWriter, AbstractDataStore]: + ... + + +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: None, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: bool, + multifile: Literal[False], + invalid_netcdf: bool, +) -> bytes: + ... + + +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: Literal[False], + multifile: Literal[False], + invalid_netcdf: bool, +) -> Delayed: + ... + + +@overload +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: Literal[True], + multifile: Literal[False], + invalid_netcdf: bool, +) -> None: + ... + + +def to_netcdf( + dataset: Dataset, + path_or_file: str | os.PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NETCDFTYPES | None = None, + group: str | None = None, + engine: T_NETCDFENGINE | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, multifile: bool = False, invalid_netcdf: bool = False, -) -> Union[Tuple[ArrayWriter, AbstractDataStore], bytes, "Delayed", None]: +) -> Tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -1090,7 +1156,7 @@ def to_netcdf( raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") if format is not None: - format = format.upper() + format = format.upper() # type: ignore[assignment] # handle scheduler specific logic scheduler = _get_scheduler() @@ -1140,7 +1206,7 @@ def to_netcdf( writes = writer.sync(compute=compute) - if path_or_file is None: + if isinstance(target, BytesIO): store.sync() return target.getvalue() finally: @@ -1387,7 +1453,7 @@ def to_zarr( region: Mapping[str, slice] = None, safe_chunks: bool = True, storage_options: Dict[str, str] = None, -): +) -> ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr ztore diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d094f47b9b8..abd93fb5822 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -23,6 +23,7 @@ Mapping, MutableMapping, Sequence, + Tuple, cast, overload, ) @@ -103,7 +104,9 @@ ) if TYPE_CHECKING: + import os from ..backends import AbstractDataStore, ZarrStore + from ..backends.api import T_NETCDFTYPES, T_NETCDFENGINE from .dataarray import DataArray from .merge import CoercibleMapping from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray @@ -1677,18 +1680,63 @@ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None: # with to_netcdf() dump_to_store(self, store, **kwargs) + @overload def to_netcdf( self, - path=None, - mode: str = "w", - format: str = None, - group: str = None, - engine: str = None, - encoding: Mapping = None, - unlimited_dims: Iterable[Hashable] = None, + path: None, + mode: Literal["w", "a"], + 
format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: bool, + invalid_netcdf: bool, + ) -> bytes: + ... + + @overload + def to_netcdf( + self, + path: str | os.PathLike, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: Literal[False], + invalid_netcdf: bool, + ) -> Delayed: + ... + + @overload + def to_netcdf( + self, + path: str | os.PathLike, + mode: Literal["w", "a"], + format: T_NETCDFTYPES | None, + group: str | None, + engine: T_NETCDFENGINE | None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None, + unlimited_dims: Iterable[Hashable] | None, + compute: Literal[True], + invalid_netcdf: bool, + ) -> None: + ... + + def to_netcdf( + self, + path: str | os.PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NETCDFTYPES | None = None, + group: str | None = None, + engine: T_NETCDFENGINE | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, invalid_netcdf: bool = False, - ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: + ) -> bytes | Delayed | None: """Write dataset contents to a netCDF file. Parameters @@ -1760,7 +1808,7 @@ def to_netcdf( encoding = {} from ..backends.api import to_netcdf - return to_netcdf( + return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( self, path, mode, From 27daca444feaca5cbbf991717cb64369fa31bebc Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 20:05:39 +0200 Subject: [PATCH 04/45] add return doc to Dataset.to_netcdf --- xarray/core/dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index abd93fb5822..d40511175ff 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1803,6 +1803,12 @@ def to_netcdf( Only valid along with ``engine="h5netcdf"``. If True, allow writing hdf5 files which are invalid netcdf as described in https://github.com/h5netcdf/h5netcdf. 
+ + Returns + ------- + * ``bytes`` if path is None + * ``dask.delayed.Delayed`` if compute is False + * None otherwise """ if encoding is None: encoding = {} From 664e35a7bc6d1fac883e6e156245584f32e8d71e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 May 2022 18:10:54 +0000 Subject: [PATCH 05/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 42 ++++++++++++++++++++++---------------- xarray/backends/common.py | 6 +++--- xarray/core/combine.py | 17 +++++++-------- xarray/core/computation.py | 3 ++- xarray/core/dataset.py | 7 ++++--- xarray/core/merge.py | 17 +++++++++------ 6 files changed, 53 insertions(+), 39 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 287443fab5c..9e3b9da547c 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -7,12 +7,13 @@ from typing import ( TYPE_CHECKING, Any, - cast, Callable, Dict, Final, Hashable, Iterable, + List, + Literal, Mapping, MutableMapping, Optional, @@ -20,8 +21,7 @@ Tuple, Type, Union, - List, - Literal, + cast, overload, ) @@ -76,7 +76,8 @@ T_NETCDFTYPES = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] - + + def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: engine: Literal["netcdf4", "pydap"] try: @@ -125,9 +126,7 @@ def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]: return engine -def _get_default_engine( - path: str, allow_remote: bool = False -) -> T_NETCDFENGINE: +def _get_default_engine(path: str, allow_remote: bool = False) -> T_NETCDFENGINE: if allow_remote and is_remote_uri(path): return _get_default_engine_remote_uri() # type: ignore[return-value] elif path.endswith(".gz"): @@ -379,7 +378,7 @@ def open_dataset( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, - backend_kwargs: Dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> Dataset: """Open and decode a dataset from a file or file-like object. 
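A sketch of the three return forms documented for ``Dataset.to_netcdf`` above (``bytes`` when no path is given, ``Delayed`` when ``compute=False``, ``None`` otherwise); the output path is hypothetical, the in-memory case assumes an engine such as scipy is installed, and the lazy case assumes dask:

    # ds is assumed to be an existing Dataset
    nc_bytes = ds.to_netcdf()                        # no path     -> bytes
    delayed = ds.to_netcdf("out.nc", compute=False)  # lazy write  -> dask.delayed.Delayed
    delayed.compute()                                # performs the actual write
    ds.to_netcdf("out.nc")                           # eager write -> None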
@@ -562,7 +561,7 @@ def open_dataarray( decode_coords: Literal["coordinates", "all"] | bool | None = None, drop_variables: str | Iterable[str] | None = None, inline_array: bool = False, - backend_kwargs: Dict[str, Any] | None = None, + backend_kwargs: dict[str, Any] | None = None, **kwargs, ) -> DataArray: """Open an DataArray from a file or file-like object containing a single @@ -730,13 +729,19 @@ def open_dataarray( def open_mfdataset( paths: str | Iterable[str | os.PathLike], chunks: T_CHUNKS = None, - concat_dim: str | DataArray | Index | Sequence[str] | Sequence[DataArray] | Sequence[Index] | None = None, + concat_dim: str + | DataArray + | Index + | Sequence[str] + | Sequence[DataArray] + | Sequence[Index] + | None = None, compat: Literal[ "identical", "equals", "broadcast_equals", "no_conflicts", "override" ] = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, engine: T_ENGINE = None, - data_vars: Literal["all", "minimal", "different"] | List[str] = "all", + data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords="different", combine: Literal["by_coords", "nested"] = "by_coords", parallel: bool = False, @@ -744,7 +749,8 @@ def open_mfdataset( attrs_file: str | os.PathLike | None = None, combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", **kwargs, ) -> Dataset: """Open multiple files as a single dataset. @@ -1024,7 +1030,7 @@ def multi_file_closer(): return combined -WRITEABLE_STORES: Dict[str, Callable] = { +WRITEABLE_STORES: dict[str, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, "h5netcdf": backends.H5NetCDFStore.open, @@ -1044,7 +1050,7 @@ def to_netcdf( compute: bool, multifile: Literal[True], invalid_netcdf: bool, -) -> Tuple[ArrayWriter, AbstractDataStore]: +) -> tuple[ArrayWriter, AbstractDataStore]: ... @@ -1111,7 +1117,7 @@ def to_netcdf( compute: bool = True, multifile: bool = False, invalid_netcdf: bool = False, -) -> Tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -1441,18 +1447,18 @@ def check_dtype(vname, var): def to_zarr( dataset: Dataset, - store: Union[MutableMapping, str, os.PathLike] = None, + store: MutableMapping | str | os.PathLike = None, chunk_store=None, mode: str = None, synchronizer=None, group: str = None, encoding: Mapping = None, compute: bool = True, - consolidated: Optional[bool] = None, + consolidated: bool | None = None, append_dim: Hashable = None, region: Mapping[str, slice] = None, safe_chunks: bool = True, - storage_options: Dict[str, str] = None, + storage_options: dict[str, str] = None, ) -> ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr ztore diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 42b5f8aedac..170ec47e510 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -371,13 +371,13 @@ class BackendEntrypoint: method is not mandatory. 
""" - open_dataset_parameters: Union[Tuple, None] = None + open_dataset_parameters: tuple | None = None """list of ``open_dataset`` method parameters""" def open_dataset( self, filename_or_obj: str | os.PathLike, - drop_variables: Tuple[str] | None = None, + drop_variables: tuple[str] | None = None, **kwargs: Any, ): """ @@ -394,4 +394,4 @@ def guess_can_open(self, filename_or_obj: str | os.PathLike): return False -BACKEND_ENTRYPOINTS: Dict[str, Type[BackendEntrypoint]] = {} +BACKEND_ENTRYPOINTS: dict[str, type[BackendEntrypoint]] = {} diff --git a/xarray/core/combine.py b/xarray/core/combine.py index e766a1ba58a..6260fd01a3e 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -3,7 +3,7 @@ import itertools import warnings from collections import Counter -from typing import Any, Callable, Iterable, Sequence, Union, Literal, List +from typing import Any, Callable, Iterable, List, Literal, Sequence, Union import pandas as pd @@ -379,9 +379,9 @@ def _nested_combine( def combine_nested( datasets: DATASET_HYPERCUBE, - concat_dim: Union[ - str, DataArray, None, Sequence[Union[str, "DataArray", pd.Index, None]] - ], + concat_dim: ( + str | DataArray | None | Sequence[Union[str, DataArray, pd.Index, None]] + ), compat: str = "no_conflicts", data_vars: str = "all", coords: str = "different", @@ -661,17 +661,18 @@ def _combine_single_variable_hypercube( # TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( - data_objects: Sequence[Union[Dataset, DataArray]] = [], + data_objects: Sequence[Dataset | DataArray] = [], compat: str = "no_conflicts", - data_vars: Literal["all", "minimal", "different"] | List[str] = "all", + data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords: str = "different", fill_value: object = dtypes.NA, join: str = "outer", combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "no_conflicts", + ] + | Callable[..., Any] = "no_conflicts", datasets: Sequence[Dataset] = None, -) -> Union[Dataset, DataArray]: +) -> Dataset | DataArray: """ Attempt to auto-magically combine the given datasets (or data arrays) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a52e6fa7756..caa215aca6a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -215,7 +215,8 @@ def build_output_coords_and_indexes( exclude_dims: AbstractSet = frozenset(), combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", ) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: """Build output coordinates and indexes for an operation. 
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d40511175ff..b46c2fd9a8c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -23,7 +23,6 @@ Mapping, MutableMapping, Sequence, - Tuple, cast, overload, ) @@ -105,8 +104,9 @@ if TYPE_CHECKING: import os + from ..backends import AbstractDataStore, ZarrStore - from ..backends.api import T_NETCDFTYPES, T_NETCDFENGINE + from ..backends.api import T_NETCDFENGINE, T_NETCDFTYPES from .dataarray import DataArray from .merge import CoercibleMapping from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray @@ -4495,7 +4495,8 @@ def merge( fill_value: Any = dtypes.NA, combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", ) -> Dataset: """Merge the arrays of two datasets into a single dataset. diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 53c3b1f5f71..6bbb40784be 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -8,12 +8,12 @@ Callable, Hashable, Iterable, + Literal, Mapping, NamedTuple, Optional, Sequence, Tuple, - Literal, Union, ) @@ -212,7 +212,8 @@ def merge_collected( compat: str = "minimal", combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", equals: dict[Hashable, bool] = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -382,7 +383,8 @@ def merge_coordinates_without_align( exclude_dims: AbstractSet = frozenset(), combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. @@ -675,7 +677,8 @@ def merge_core( join: str = "outer", combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", priority_arg: int | None = None, explicit_coords: Sequence | None = None, indexes: Mapping[Any, Any] | None = None, @@ -767,7 +770,8 @@ def merge( fill_value: object = dtypes.NA, combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any] = "override", + ] + | Callable[..., Any] = "override", ) -> Dataset: """Merge any number of xarray objects into a single Dataset as variables. 
@@ -1017,7 +1021,8 @@ def dataset_merge_method( fill_value: Any, combine_attrs: Literal[ "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] | Callable[..., Any], + ] + | Callable[..., Any], ) -> _MergeResult: """Guts of the Dataset.merge method.""" # we are locked into supporting overwrite_vars for the Dataset.merge From 3102f73f270218ea92812bc88eff61a4587f85e7 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 20:27:07 +0200 Subject: [PATCH 06/45] fix import error --- xarray/backends/api.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 287443fab5c..ac215ca7b44 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -51,6 +51,17 @@ Delayed = None # type: ignore from .common import BackendEntrypoint + T_NETCDFENGINE = Literal["netcdf4", "scipy", "h5netcdf"] + T_ENGINE = Union[ + T_NETCDFENGINE, + Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"], + Type[BackendEntrypoint], + ] + T_CHUNKS = Union[int, dict[Any, Any], Literal["auto"], None] + T_NETCDFTYPES = Literal[ + "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" + ] + DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -66,16 +77,7 @@ "zarr": backends.ZarrStore.open_group, } -T_NETCDFENGINE = Literal["netcdf4", "scipy", "h5netcdf"] -T_ENGINE = Union[ - T_NETCDFENGINE, - Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"], - Type[BackendEntrypoint], -] -T_CHUNKS = Union[int, dict[Any, Any], Literal["auto"], None] -T_NETCDFTYPES = Literal[ - "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" -] + def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: engine: Literal["netcdf4", "pydap"] From 2007cd53cefdf0d5b6db2f2577284718412d88f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 May 2022 18:29:34 +0000 Subject: [PATCH 07/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 5 +---- xarray/backends/common.py | 2 +- xarray/core/combine.py | 6 ++---- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7492276c4bf..429c6217a6d 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -8,17 +8,13 @@ TYPE_CHECKING, Any, Callable, - Dict, Final, Hashable, Iterable, - List, Literal, Mapping, MutableMapping, - Optional, Sequence, - Tuple, Type, Union, cast, @@ -77,6 +73,7 @@ "zarr": backends.ZarrStore.open_group, } + def _get_default_engine_remote_uri() -> Literal["netcdf4", "pydap"]: engine: Literal["netcdf4", "pydap"] try: diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 170ec47e510..52738c639e1 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -4,7 +4,7 @@ import os import time import traceback -from typing import Any, Dict, Tuple, Type, Union +from typing import Any import numpy as np diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 6260fd01a3e..60446425f66 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -3,7 +3,7 @@ import itertools import warnings from collections import Counter -from typing import Any, Callable, Iterable, List, Literal, Sequence, Union +from typing import Any, Callable, Iterable, Literal, Sequence, Union import pandas as pd @@ -379,9 +379,7 @@ def _nested_combine( 
def combine_nested( datasets: DATASET_HYPERCUBE, - concat_dim: ( - str | DataArray | None | Sequence[Union[str, DataArray, pd.Index, None]] - ), + concat_dim: (str | DataArray | None | Sequence[str | DataArray | pd.Index | None]), compat: str = "no_conflicts", data_vars: str = "all", coords: str = "different", From c7dd1884a21292ef84fb9c3890ac2e0d20978d9f Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 20:31:14 +0200 Subject: [PATCH 08/45] replace tuple[x] by Tuple[x] for py3.8 --- xarray/backends/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7492276c4bf..d6721a29f8c 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1049,7 +1049,7 @@ def to_netcdf( compute: bool, multifile: Literal[True], invalid_netcdf: bool, -) -> tuple[ArrayWriter, AbstractDataStore]: +) -> Tuple[ArrayWriter, AbstractDataStore]: ... @@ -1116,7 +1116,7 @@ def to_netcdf( compute: bool = True, multifile: bool = False, invalid_netcdf: bool = False, -) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: +) -> Tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file From befd121fc91005dcf325e35f97a15bc40a94edf6 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 20:38:51 +0200 Subject: [PATCH 09/45] fix some merge errors --- xarray/backends/api.py | 8 ++++---- xarray/core/dataset.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ed6a9368ab4..fb160ab83c2 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1026,7 +1026,7 @@ def multi_file_closer(): return combined -WRITEABLE_STORES: dict[str, Callable] = { +WRITEABLE_STORES: dict[T_NETCDFENGINE, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, "h5netcdf": backends.H5NetCDFStore.open, @@ -1046,7 +1046,7 @@ def to_netcdf( compute: bool, multifile: Literal[True], invalid_netcdf: bool, -) -> Tuple[ArrayWriter, AbstractDataStore]: +) -> tuple[ArrayWriter, AbstractDataStore]: ... 
@@ -1113,7 +1113,7 @@ def to_netcdf( compute: bool = True, multifile: bool = False, invalid_netcdf: bool = False, -) -> Tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: +) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file @@ -1443,7 +1443,7 @@ def check_dtype(vname, var): def to_zarr( dataset: Dataset, - store: MutableMapping | str | os.PathLike = None, + store: MutableMapping | str | os.PathLike | None = None, chunk_store=None, mode: str = None, synchronizer=None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b46c2fd9a8c..d1be1eeaf2b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -32,7 +32,6 @@ import xarray as xr -from ..backends.common import ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods From 86c5536d2b9f322f65646f623260eda1d3bea1a5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 15 May 2022 21:02:38 +0200 Subject: [PATCH 10/45] add overloads to to_zarr --- xarray/backends/api.py | 52 +++++++++++++++++++++++++++----- xarray/core/dataset.py | 68 ++++++++++++++++++++++++++++++++---------- 2 files changed, 97 insertions(+), 23 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index fb160ab83c2..a466af42d1a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1441,20 +1441,58 @@ def check_dtype(vname, var): check_dtype(vname, var) +@overload +def to_zarr( + dataset: Dataset, + store: MutableMapping | str | os.PathLike | None, + chunk_store: MutableMapping | str | os.PathLike | None, + mode: Literal["w", "w-", "a", "r+", None], + synchronizer, + group: str | None, + encoding: Mapping | None, + compute: Literal[True], + consolidated: bool | None, + append_dim: Hashable | None, + region: Mapping[str, slice] | None, + safe_chunks: bool, + storage_options: dict[str, str] | None, +) -> ZarrStore: + ... + + +@overload +def to_zarr( + dataset: Dataset, + store: MutableMapping | str | os.PathLike | None, + chunk_store: MutableMapping | str | os.PathLike | None, + mode: Literal["w", "w-", "a", "r+", None], + synchronizer, + group: str | None, + encoding: Mapping | None, + compute: Literal[False], + consolidated: bool | None, + append_dim: Hashable | None, + region: Mapping[str, slice] | None, + safe_chunks: bool, + storage_options: dict[str, str] | None, +) -> Delayed: + ... 
+ + def to_zarr( dataset: Dataset, store: MutableMapping | str | os.PathLike | None = None, - chunk_store=None, - mode: str = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, - group: str = None, - encoding: Mapping = None, + group: str | None = None, + encoding: Mapping | None = None, compute: bool = True, consolidated: bool | None = None, - append_dim: Hashable = None, - region: Mapping[str, slice] = None, + append_dim: Hashable | None= None, + region: Mapping[str, slice] | None = None, safe_chunks: bool = True, - storage_options: dict[str, str] = None, + storage_options: dict[str, str] | None = None, ) -> ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr ztore diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d1be1eeaf2b..bdc68f20285 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -102,8 +102,6 @@ ) if TYPE_CHECKING: - import os - from ..backends import AbstractDataStore, ZarrStore from ..backends.api import T_NETCDFENGINE, T_NETCDFTYPES from .dataarray import DataArray @@ -1697,7 +1695,7 @@ def to_netcdf( @overload def to_netcdf( self, - path: str | os.PathLike, + path: str | PathLike, mode: Literal["w", "a"], format: T_NETCDFTYPES | None, group: str | None, @@ -1712,7 +1710,7 @@ def to_netcdf( @overload def to_netcdf( self, - path: str | os.PathLike, + path: str | PathLike, mode: Literal["w", "a"], format: T_NETCDFTYPES | None, group: str | None, @@ -1726,7 +1724,7 @@ def to_netcdf( def to_netcdf( self, - path: str | os.PathLike | None = None, + path: str | PathLike | None = None, mode: Literal["w", "a"] = "w", format: T_NETCDFTYPES | None = None, group: str | None = None, @@ -1826,21 +1824,57 @@ def to_netcdf( invalid_netcdf=invalid_netcdf, ) + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike | None, + chunk_store: MutableMapping | str | PathLike | None, + mode: Literal["w", "w-", "a", "r+", None], + synchronizer, + group: str | None, + encoding: Mapping | None, + compute: Literal[False], + consolidated: bool | None, + append_dim: Hashable | None, + region: Mapping[str, slice] | None, + safe_chunks: bool, + storage_options: dict[str, str] | None, + ) -> Delayed: + ... + + @overload + def to_zarr( + self, + store: MutableMapping | str | PathLike | None, + chunk_store: MutableMapping | str | PathLike | None, + mode: Literal["w", "w-", "a", "r+", None], + synchronizer, + group: str | None, + encoding: Mapping | None, + compute: Literal[True], + consolidated: bool | None, + append_dim: Hashable | None, + region: Mapping[str, slice] | None, + safe_chunks: bool, + storage_options: dict[str, str] | None, + ) -> ZarrStore: + ... + def to_zarr( self, store: MutableMapping | str | PathLike | None = None, chunk_store: MutableMapping | str | PathLike | None = None, - mode: str = None, + mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, - group: str = None, - encoding: Mapping = None, + group: str | None = None, + encoding: Mapping | None = None, compute: bool = True, consolidated: bool | None = None, - append_dim: Hashable = None, - region: Mapping[str, slice] = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, safe_chunks: bool = True, - storage_options: dict[str, str] = None, - ) -> ZarrStore: + storage_options: dict[str, str] | None = None, + ) -> ZarrStore | Delayed: """Write dataset contents to a zarr group. 
Zarr chunks are determined in the following way: @@ -1921,6 +1955,11 @@ def to_zarr( Any additional parameters for the storage backend (ignored for local paths). + Returns + ------- + * ``dask.delayed.Delayed`` if compute is False + * ZarrStore otherwise + References ---------- https://zarr.readthedocs.io/ @@ -1945,10 +1984,7 @@ def to_zarr( """ from ..backends.api import to_zarr - if encoding is None: - encoding = {} - - return to_zarr( + return to_zarr( # type: ignore self, store=store, chunk_store=chunk_store, From e5b7e484547fda33a43aa20e26444a4b36b00984 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 May 2022 19:04:16 +0000 Subject: [PATCH 11/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 4 ++-- xarray/core/dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index a466af42d1a..31bb2a3d3c4 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1459,7 +1459,7 @@ def to_zarr( ) -> ZarrStore: ... - + @overload def to_zarr( dataset: Dataset, @@ -1489,7 +1489,7 @@ def to_zarr( encoding: Mapping | None = None, compute: bool = True, consolidated: bool | None = None, - append_dim: Hashable | None= None, + append_dim: Hashable | None = None, region: Mapping[str, slice] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bdc68f20285..83ced128f02 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1841,7 +1841,7 @@ def to_zarr( storage_options: dict[str, str] | None, ) -> Delayed: ... - + @overload def to_zarr( self, From a60e7a6d08487b280fb23dd823d6ec598988858e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 19:16:46 +0200 Subject: [PATCH 12/45] fix absolute import --- xarray/backends/api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 31bb2a3d3c4..7d8d1d692c1 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -23,7 +23,6 @@ import numpy as np -from xarray.backends.zarr import ZarrStore from .. import backends, conventions from ..core import indexing @@ -1456,7 +1455,7 @@ def to_zarr( region: Mapping[str, slice] | None, safe_chunks: bool, storage_options: dict[str, str] | None, -) -> ZarrStore: +) -> backends.ZarrStore: ... 
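A sketch of the two return forms documented for ``Dataset.to_zarr`` above; the store path is hypothetical and the example assumes zarr is installed, plus dask for the ``compute=False`` case:

    # ds is assumed to be an existing Dataset
    zstore = ds.to_zarr("out.zarr", mode="w")                  # eager -> ZarrStore
    delayed = ds.to_zarr("out.zarr", mode="w", compute=False)  # lazy  -> dask.delayed.Delayed
    delayed.compute()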
@@ -1493,7 +1492,7 @@ def to_zarr( region: Mapping[str, slice] | None = None, safe_chunks: bool = True, storage_options: dict[str, str] | None = None, -) -> ZarrStore | Delayed: +) -> backends.ZarrStore | Delayed: """This function creates an appropriate datastore for writing a dataset to a zarr ztore From 6f7c875878457f7ba1722e07f854b4fdfb5616b3 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 19:19:58 +0200 Subject: [PATCH 13/45] CamelCase type vars --- xarray/backends/api.py | 46 +++++++++++++++++++++--------------------- xarray/core/dataset.py | 18 ++++++++--------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7d8d1d692c1..6a253de5143 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -46,14 +46,14 @@ Delayed = None # type: ignore from .common import BackendEntrypoint - T_NETCDFENGINE = Literal["netcdf4", "scipy", "h5netcdf"] - T_ENGINE = Union[ - T_NETCDFENGINE, + T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] + T_Engine = Union[ + T_NetcdfEngine, Literal["pydap", "pynio", "pseudonetcdf", "cfgrib", "zarr"], Type[BackendEntrypoint], ] - T_CHUNKS = Union[int, dict[Any, Any], Literal["auto"], None] - T_NETCDFTYPES = Literal[ + T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] + T_NetcdfTypes = Literal[ "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC" ] @@ -121,7 +121,7 @@ def _get_default_engine_netcdf() -> Literal["netcdf4", "scipy"]: return engine -def _get_default_engine(path: str, allow_remote: bool = False) -> T_NETCDFENGINE: +def _get_default_engine(path: str, allow_remote: bool = False) -> T_NetcdfEngine: if allow_remote and is_remote_uri(path): return _get_default_engine_remote_uri() # type: ignore[return-value] elif path.endswith(".gz"): @@ -361,8 +361,8 @@ def _dataset_from_backend_dataset( def open_dataset( filename_or_obj: str | os.PathLike, *, - engine: T_ENGINE = None, - chunks: T_CHUNKS = None, + engine: T_Engine = None, + chunks: T_Chunks = None, cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, @@ -544,8 +544,8 @@ def open_dataset( def open_dataarray( filename_or_obj: str | os.PathLike, *, - engine: T_ENGINE = None, - chunks: T_CHUNKS = None, + engine: T_Engine = None, + chunks: T_Chunks = None, cache: bool | None = None, decode_cf: bool | None = None, mask_and_scale: bool | None = None, @@ -723,7 +723,7 @@ def open_dataarray( def open_mfdataset( paths: str | Iterable[str | os.PathLike], - chunks: T_CHUNKS = None, + chunks: T_Chunks = None, concat_dim: str | DataArray | Index @@ -735,7 +735,7 @@ def open_mfdataset( "identical", "equals", "broadcast_equals", "no_conflicts", "override" ] = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, - engine: T_ENGINE = None, + engine: T_Engine = None, data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords="different", combine: Literal["by_coords", "nested"] = "by_coords", @@ -1025,7 +1025,7 @@ def multi_file_closer(): return combined -WRITEABLE_STORES: dict[T_NETCDFENGINE, Callable] = { +WRITEABLE_STORES: dict[T_NetcdfEngine, Callable] = { "netcdf4": backends.NetCDF4DataStore.open, "scipy": backends.ScipyDataStore, "h5netcdf": backends.H5NetCDFStore.open, @@ -1037,9 +1037,9 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: 
T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: bool, @@ -1054,9 +1054,9 @@ def to_netcdf( dataset: Dataset, path_or_file: None, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: bool, @@ -1071,9 +1071,9 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: Literal[False], @@ -1088,9 +1088,9 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: Literal[True], @@ -1104,9 +1104,9 @@ def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NETCDFTYPES | None = None, + format: T_NetcdfTypes | None = None, group: str | None = None, - engine: T_NETCDFENGINE | None = None, + engine: T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 83ced128f02..1d6a27734f5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -103,7 +103,7 @@ if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore - from ..backends.api import T_NETCDFENGINE, T_NETCDFTYPES + from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .dataarray import DataArray from .merge import CoercibleMapping from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray @@ -1682,9 +1682,9 @@ def to_netcdf( self, path: None, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: bool, @@ -1697,9 +1697,9 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: Literal[False], @@ -1712,9 +1712,9 @@ def to_netcdf( self, path: str | PathLike, mode: Literal["w", "a"], - format: T_NETCDFTYPES | None, + format: T_NetcdfTypes | None, group: str | None, - engine: T_NETCDFENGINE | None, + engine: T_NetcdfEngine | None, encoding: Mapping[Hashable, Mapping[str, Any]] | None, unlimited_dims: Iterable[Hashable] | None, compute: Literal[True], @@ -1726,9 +1726,9 @@ def to_netcdf( self, path: str | PathLike | None = None, mode: Literal["w", "a"] = "w", - format: T_NETCDFTYPES | None = None, + format: T_NetcdfTypes | None = None, group: str | None = None, - engine: T_NETCDFENGINE | None = None, + engine: 
T_NetcdfEngine | None = None, encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, From 6f454eead61bb563046ab0d955fac3650a6e66bf Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 19:46:54 +0200 Subject: [PATCH 14/45] move some literal type to core.types --- xarray/backends/api.py | 10 ++---- xarray/core/combine.py | 12 +++---- xarray/core/concat.py | 28 +++++++-------- xarray/core/dataarray.py | 14 ++++---- xarray/core/dataset.py | 28 +++++++-------- xarray/core/indexes.py | 6 ++-- xarray/core/merge.py | 75 +++++++++++++++++++++++----------------- xarray/core/types.py | 14 ++++++-- xarray/core/utils.py | 8 ++--- xarray/core/variable.py | 6 ++-- 10 files changed, 107 insertions(+), 94 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 6a253de5143..ee28f4008f6 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -45,6 +45,7 @@ except ImportError: Delayed = None # type: ignore from .common import BackendEntrypoint + from ..core.types import CompatOptions, CombineAttrsOptions T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ @@ -731,9 +732,7 @@ def open_mfdataset( | Sequence[DataArray] | Sequence[Index] | None = None, - compat: Literal[ - "identical", "equals", "broadcast_equals", "no_conflicts", "override" - ] = "no_conflicts", + compat: CompatOptions = "no_conflicts", preprocess: Callable[[Dataset], Dataset] | None = None, engine: T_Engine = None, data_vars: Literal["all", "minimal", "different"] | list[str] = "all", @@ -742,10 +741,7 @@ def open_mfdataset( parallel: bool = False, join: Literal["outer", "inner", "left", "right", "exact", "override"] = "outer", attrs_file: str | os.PathLike | None = None, - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", **kwargs, ) -> Dataset: """Open multiple files as a single dataset. 
diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 60446425f66..eb748b4ed3b 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -3,7 +3,7 @@ import itertools import warnings from collections import Counter -from typing import Any, Callable, Iterable, Literal, Sequence, Union +from typing import Any, Callable, Iterable, Literal, Sequence, TYPE_CHECKING, Union import pandas as pd @@ -14,6 +14,9 @@ from .merge import merge from .utils import iterate_nested +if TYPE_CHECKING: + from .types import CombineAttrsOptions, CompatOptions + def _infer_concat_order_from_positions(datasets): return dict(_infer_tile_ids_from_nested_list(datasets, ())) @@ -660,15 +663,12 @@ def _combine_single_variable_hypercube( # TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( data_objects: Sequence[Dataset | DataArray] = [], - compat: str = "no_conflicts", + compat: CompatOptions = "no_conflicts", data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords: str = "different", fill_value: object = dtypes.NA, join: str = "outer", - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "no_conflicts", + combine_attrs: CombineAttrsOptions = "no_conflicts", datasets: Sequence[Dataset] = None, ) -> Dataset | DataArray: """ diff --git a/xarray/core/concat.py b/xarray/core/concat.py index df493ad1c5e..1c8fcb98192 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -4,6 +4,8 @@ import pandas as pd +from xarray.core.types import CombineAttrsOptions + from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv @@ -20,20 +22,16 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - -compat_options = Literal[ - "identical", "equals", "broadcast_equals", "no_conflicts", "override" -] -concat_options = Literal["all", "minimal", "different"] + from .types import CompatOptions, ConcatOptions @overload def concat( objs: Iterable[Dataset], dim: Hashable | DataArray | pd.Index, - data_vars: concat_options | list[Hashable] = "all", - coords: concat_options | list[Hashable] = "different", - compat: compat_options = "equals", + data_vars: ConcatOptions | list[Hashable] = "all", + coords: ConcatOptions | list[Hashable] = "different", + compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", @@ -46,9 +44,9 @@ def concat( def concat( objs: Iterable[DataArray], dim: Hashable | DataArray | pd.Index, - data_vars: concat_options | list[Hashable] = "all", - coords: concat_options | list[Hashable] = "different", - compat: compat_options = "equals", + data_vars: ConcatOptions | list[Hashable] = "all", + coords: ConcatOptions | list[Hashable] = "different", + compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, join: str = "outer", @@ -420,11 +418,11 @@ def _dataset_concat( dim: str | DataArray | pd.Index, data_vars: str | list[str], coords: str | list[str], - compat: str, + compat: CompatOptions, positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", - combine_attrs: str = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: """ Concatenate a sequence of datasets along a new or existing dimension @@ -609,11 +607,11 @@ def _dataarray_concat( dim: str | DataArray | pd.Index, data_vars: str | list[str], 
coords: str | list[str], - compat: str, + compat: CompatOptions, positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, join: str = "outer", - combine_attrs: str = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> DataArray: from .dataarray import DataArray diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 64c4e419788..4fc5f7c0c5c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -78,7 +78,7 @@ except ImportError: iris_Cube = None - from .types import ErrorChoice, ErrorChoiceWithWarn, T_DataArray, T_Xarray + from .types import ErrorOptions, ErrorOptionsWithWarn, T_DataArray, T_Xarray def _infer_coords_and_dims( @@ -1186,7 +1186,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, ) -> DataArray: """Return a new DataArray whose data is given by integer indexing @@ -2350,7 +2350,7 @@ def transpose( self, *dims: Hashable, transpose_coords: bool = True, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", ) -> DataArray: """Return a new DataArray object with transposed dimensions. @@ -2401,7 +2401,7 @@ def T(self) -> DataArray: return self.transpose() def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise" ) -> DataArray: """Returns an array with dropped variables. @@ -2427,7 +2427,7 @@ def drop( labels: Mapping = None, dim: Hashable = None, *, - errors: ErrorChoice = "raise", + errors: ErrorOptions = "raise", **labels_kwargs, ) -> DataArray: """Backward compatible method based on `drop_vars` and `drop_sel` @@ -2446,7 +2446,7 @@ def drop_sel( self, labels: Mapping[Any, Any] = None, *, - errors: ErrorChoice = "raise", + errors: ErrorOptions = "raise", **labels_kwargs, ) -> DataArray: """Drop index labels from this DataArray. 
@@ -4604,7 +4604,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, ) -> DataArray: """Return a new data array indexed along the specified diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1d6a27734f5..b6da496b921 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -31,6 +31,7 @@ import pandas as pd import xarray as xr +from xarray.core.types import CombineAttrsOptions, CompatOptions from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings @@ -106,7 +107,7 @@ from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .dataarray import DataArray from .merge import CoercibleMapping - from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray + from .types import ErrorOptions, ErrorOptionsWithWarn, T_Xarray try: from dask.delayed import Delayed @@ -2154,7 +2155,7 @@ def chunk( return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Any, Any], missing_dims: ErrorChoiceWithWarn = "raise" + self, indexers: Mapping[Any, Any], missing_dims: ErrorOptionsWithWarn = "raise" ) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]: """Here we make sure + indexer has a valid keys @@ -2259,7 +2260,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, ) -> Dataset: """Returns a new dataset with each array indexed along the specified @@ -2350,7 +2351,7 @@ def _isel_fancy( indexers: Mapping[Any, Any], *, drop: bool, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", ) -> Dataset: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) @@ -4525,13 +4526,10 @@ def merge( self, other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), - compat: str = "no_conflicts", + compat: CompatOptions = "no_conflicts", join: str = "outer", fill_value: Any = dtypes.NA, - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: """Merge the arrays of two datasets into a single dataset. @@ -4627,7 +4625,7 @@ def _assert_all_in_dataset( ) def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise" ) -> Dataset: """Drop variables from this dataset. @@ -4680,7 +4678,7 @@ def drop_vars( ) def drop( - self, labels=None, dim=None, *, errors: ErrorChoice = "raise", **labels_kwargs + self, labels=None, dim=None, *, errors: ErrorOptions = "raise", **labels_kwargs ): """Backward compatible method based on `drop_vars` and `drop_sel` @@ -4730,7 +4728,7 @@ def drop( ) return self.drop_sel(labels, errors=errors) - def drop_sel(self, labels=None, *, errors: ErrorChoice = "raise", **labels_kwargs): + def drop_sel(self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs): """Drop index labels from this dataset. 
Parameters @@ -4865,7 +4863,7 @@ def drop_isel(self, indexers=None, **indexers_kwargs): return ds def drop_dims( - self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" + self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise" ) -> Dataset: """Drop dimensions and associated variables from this dataset. @@ -4905,7 +4903,7 @@ def drop_dims( def transpose( self, *dims: Hashable, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", ) -> Dataset: """Return a new Dataset object with all array dimensions transposed. @@ -7839,7 +7837,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, ) -> Dataset: """Return a new dataset with each array indexed along the specified diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9884a756fe6..ee3ef17ed65 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -25,7 +25,7 @@ from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar if TYPE_CHECKING: - from .types import ErrorChoice, T_Index + from .types import ErrorOptions, T_Index from .variable import Variable IndexVars = Dict[Any, "Variable"] @@ -1098,7 +1098,7 @@ def is_multi(self, key: Hashable) -> bool: return len(self._id_coord_names[self._coord_name_id[key]]) > 1 def get_all_coords( - self, key: Hashable, errors: ErrorChoice = "raise" + self, key: Hashable, errors: ErrorOptions = "raise" ) -> dict[Hashable, Variable]: """Return all coordinates having the same index. @@ -1129,7 +1129,7 @@ def get_all_coords( return {k: self._variables[k] for k in all_coord_names} def get_all_dims( - self, key: Hashable, errors: ErrorChoice = "raise" + self, key: Hashable, errors: ErrorOptions = "raise" ) -> Mapping[Hashable, int]: """Return all dimensions shared by an index. diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 6bbb40784be..758f5d72a7c 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -36,6 +36,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset + from .types import CompatOptions, CombineAttrsOptions DimsLike = Union[Hashable, Sequence[Hashable]] ArrayLike = Any @@ -96,8 +97,8 @@ class MergeError(ValueError): def unique_variable( name: Hashable, variables: list[Variable], - compat: str = "broadcast_equals", - equals: bool = None, + compat: CompatOptions = "broadcast_equals", + equals: bool | None = None, ) -> Variable: """Return the unique variable from a list of variables or raise MergeError. @@ -209,12 +210,9 @@ def _assert_prioritized_valid( def merge_collected( grouped: dict[Hashable, list[MergeElement]], prioritized: Mapping[Any, MergeElement] = None, - compat: str = "minimal", - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", - equals: dict[Hashable, bool] = None, + compat: CompatOptions = "minimal", + combine_attrs: CombineAttrsOptions = "override", + equals: dict[Hashable, bool] | None = None, ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge dicts of variables, while resolving conflicts appropriately. @@ -224,6 +222,22 @@ def merge_collected( prioritized : mapping compat : str Type of equality check to use when checking for conflicts. 
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ + "override"} or callable, default: "override" + A callable or a string indicating how to combine attrs of the objects being + merged: + + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have + the same name must also have the same value. + - "drop_conflicts": attrs from all objects are combined, any that have + the same name but different values are dropped. + - "override": skip comparing and copy attrs from the first dataset to + the result. + + If a callable, it must expect a sequence of ``attrs`` dicts and a context object + as its only parameters. equals : mapping, optional corresponding to result of compat test @@ -381,10 +395,7 @@ def merge_coordinates_without_align( objects: list[Coordinates], prioritized: Mapping[Any, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]: """Merge variables/indexes from coordinates without automatic alignments. @@ -488,7 +499,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], priority_arg: int | None, compat: str = "equals" + objects: list[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals" ) -> dict[Hashable, MergeElement]: """Extract the priority variable from a list of mappings. @@ -502,8 +513,19 @@ def _get_priority_vars_and_indexes( Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. - compat : {"identical", "equals", "broadcast_equals", "no_conflicts"}, optional - Compatibility checks to use when merging variables. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional + String indicating how to compare non-concatenated variables of the same name for + potential conflicts. This is passed down to merge. + + - "broadcast_equals": all values must be equal when variables are + broadcast against each other to ensure common dimensions. + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the + same. + - "no_conflicts": only values which are not null in both datasets + must be equal. The returned dataset then contains the combination + of all non-null values. 
+ - "override": skip comparing and pick variable from first dataset Returns ------- @@ -522,7 +544,7 @@ def _get_priority_vars_and_indexes( def merge_coords( objects: Iterable[CoercibleMapping], - compat: str = "minimal", + compat: CompatOptions = "minimal", join: str = "outer", priority_arg: int | None = None, indexes: Mapping[Any, Index] | None = None, @@ -673,12 +695,9 @@ class _MergeResult(NamedTuple): def merge_core( objects: Iterable[CoercibleMapping], - compat: str = "broadcast_equals", + compat: CompatOptions = "broadcast_equals", join: str = "outer", - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", priority_arg: int | None = None, explicit_coords: Sequence | None = None, indexes: Mapping[Any, Any] | None = None, @@ -765,13 +784,10 @@ def merge_core( def merge( objects: Iterable[DataArray | CoercibleMapping], - compat: str = "no_conflicts", + compat: CompatOptions = "no_conflicts", join: str = "outer", fill_value: object = dtypes.NA, - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: """Merge any number of xarray objects into a single Dataset as variables. @@ -1016,13 +1032,10 @@ def dataset_merge_method( dataset: Dataset, other: CoercibleMapping, overwrite_vars: Hashable | Iterable[Hashable], - compat: str, + compat: CompatOptions, join: str, fill_value: Any, - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any], + combine_attrs: CombineAttrsOptions, ) -> _MergeResult: """Guts of the Dataset.merge method.""" # we are locked into supporting overwrite_vars for the Dataset.merge diff --git a/xarray/core/types.py b/xarray/core/types.py index 6dbc57ce797..6e4c3523925 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from typing import Any, Callable, TYPE_CHECKING, Literal, TypeVar, Union import numpy as np @@ -34,5 +34,13 @@ VarCompatible = Union["Variable", "ScalarOrArray"] GroupByIncompatible = Union["Variable", "GroupBy"] -ErrorChoice = Literal["raise", "ignore"] -ErrorChoiceWithWarn = Literal["raise", "warn", "ignore"] +ErrorOptions = Literal["raise", "ignore"] +ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] +CompatOptions = Literal[ + "identical", "equals", "broadcast_equals", "no_conflicts", "override", "minimal" +] +ConcatOptions = Literal["all", "minimal", "different"] +CombineAttrsOptions = Union[ + Literal["drop", "identical", "no_conflicts", "drop_conflicts", "override"], + Callable[..., Any], +] \ No newline at end of file diff --git a/xarray/core/utils.py b/xarray/core/utils.py index eda08becc20..a2f95123892 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -30,7 +30,7 @@ import pandas as pd if TYPE_CHECKING: - from .types import ErrorChoiceWithWarn + from .types import ErrorOptionsWithWarn K = TypeVar("K") V = TypeVar("V") @@ -761,7 +761,7 @@ def __len__(self) -> int: def infix_dims( dims_supplied: Collection, dims_all: Collection, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", ) -> Iterator: """ Resolves a supplied list containing an ellipsis representing other items, to @@ -809,7 +809,7 @@ def get_temp_dimname(dims: 
Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( indexers: Mapping[Any, Any], dims: list | Mapping[Any, int], - missing_dims: ErrorChoiceWithWarn, + missing_dims: ErrorOptionsWithWarn, ) -> Mapping[Hashable, Any]: """Depending on the setting of missing_dims, drop any dimensions from indexers that are not present in dims. @@ -855,7 +855,7 @@ def drop_dims_from_indexers( def drop_missing_dims( - supplied_dims: Collection, dims: Collection, missing_dims: ErrorChoiceWithWarn + supplied_dims: Collection, dims: Collection, missing_dims: ErrorOptionsWithWarn ) -> Collection: """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that are not present in dims. diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1e684a72984..20f6bae8ad5 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -59,7 +59,7 @@ BASIC_INDEXING_TYPES = integer_types + (slice,) if TYPE_CHECKING: - from .types import ErrorChoiceWithWarn, T_Variable + from .types import ErrorOptionsWithWarn, T_Variable class MissingDimensionsError(ValueError): @@ -1172,7 +1172,7 @@ def _to_dense(self): def isel( self: T_Variable, indexers: Mapping[Any, Any] = None, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, ) -> T_Variable: """Return a new array indexed along the specified dimension(s). @@ -1449,7 +1449,7 @@ def roll(self, shifts=None, **shifts_kwargs): def transpose( self, *dims, - missing_dims: ErrorChoiceWithWarn = "raise", + missing_dims: ErrorOptionsWithWarn = "raise", ) -> Variable: """Return a new Variable object with transposed dimensions. From e0bc3cbc2779c5d0572f3274eec78c9f2590b5f6 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 20:09:30 +0200 Subject: [PATCH 15/45] add JoinOptions to core.types --- xarray/backends/api.py | 4 ++-- xarray/core/alignment.py | 7 +++--- xarray/core/combine.py | 26 ++++++++++---------- xarray/core/concat.py | 47 +++++++++++++++++++++++++------------ xarray/core/dataset.py | 4 ++-- xarray/core/merge.py | 10 ++++---- xarray/core/types.py | 3 ++- xarray/tests/test_concat.py | 10 ++++---- 8 files changed, 66 insertions(+), 45 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ee28f4008f6..9547eaa5fde 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -45,7 +45,7 @@ except ImportError: Delayed = None # type: ignore from .common import BackendEntrypoint - from ..core.types import CompatOptions, CombineAttrsOptions + from ..core.types import CompatOptions, CombineAttrsOptions, JoinOptions T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ @@ -739,7 +739,7 @@ def open_mfdataset( coords="different", combine: Literal["by_coords", "nested"] = "by_coords", parallel: bool = False, - join: Literal["outer", "inner", "left", "right", "exact", "override"] = "outer", + join: JoinOptions = "outer", attrs_file: str | os.PathLike | None = None, combine_attrs: CombineAttrsOptions = "override", **kwargs, diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index e29d2b2a67f..4c507cd18e1 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset + from .types import JoinOptions DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) @@ -557,7 +558,7 @@ def align(self) -> None: def align( *objects: DataAlignable, - join="inner", + join: JoinOptions 
= "inner", copy=True, indexes=None, exclude=frozenset(), @@ -764,7 +765,7 @@ def align( def deep_align( objects, - join="inner", + join: JoinOptions = "inner", copy=True, indexes=None, exclude=frozenset(), @@ -834,7 +835,7 @@ def is_alignable(obj): if key is no_key: out[position] = aligned_obj else: - out[position][key] = aligned_obj + out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this? # something went wrong: we should have replaced all sentinel values for arg in out: diff --git a/xarray/core/combine.py b/xarray/core/combine.py index eb748b4ed3b..ccbea0520d6 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -15,7 +15,7 @@ from .utils import iterate_nested if TYPE_CHECKING: - from .types import CombineAttrsOptions, CompatOptions + from .types import CombineAttrsOptions, CompatOptions, JoinOptions def _infer_concat_order_from_positions(datasets): @@ -193,10 +193,10 @@ def _combine_nd( concat_dims, data_vars="all", coords="different", - compat="no_conflicts", + compat: CompatOptions = "no_conflicts", fill_value=dtypes.NA, - join="outer", - combine_attrs="drop", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "drop", ): """ Combines an N-dimensional structure of datasets into one by applying a @@ -255,10 +255,10 @@ def _combine_all_along_first_dim( dim, data_vars, coords, - compat, + compat: CompatOptions, fill_value=dtypes.NA, - join="outer", - combine_attrs="drop", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "drop", ): # Group into lines of datasets which must be combined along dim @@ -281,12 +281,12 @@ def _combine_all_along_first_dim( def _combine_1d( datasets, concat_dim, - compat="no_conflicts", + compat: CompatOptions = "no_conflicts", data_vars="all", coords="different", fill_value=dtypes.NA, - join="outer", - combine_attrs="drop", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "drop", ): """ Applies either concat or merge to 1D list of datasets depending on value @@ -341,7 +341,7 @@ def _nested_combine( coords, ids, fill_value=dtypes.NA, - join="outer", + join: JoinOptions = "outer", combine_attrs="drop", ): @@ -387,7 +387,7 @@ def combine_nested( data_vars: str = "all", coords: str = "different", fill_value: object = dtypes.NA, - join: str = "outer", + join: JoinOptions = "outer", combine_attrs: str = "drop", ) -> Dataset: """ @@ -667,7 +667,7 @@ def combine_by_coords( data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords: str = "different", fill_value: object = dtypes.NA, - join: str = "outer", + join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "no_conflicts", datasets: Sequence[Dataset] = None, ) -> Dataset | DataArray: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 1c8fcb98192..8e420a25c8c 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import CompatOptions, ConcatOptions + from .types import CompatOptions, ConcatOptions, JoinOptions @overload @@ -34,8 +34,8 @@ def concat( compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, - join: str = "outer", - combine_attrs: str = "override", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: ... 
@@ -49,8 +49,8 @@ def concat( compat: CompatOptions = "equals", positions: Iterable[Iterable[int]] | None = None, fill_value: object = dtypes.NA, - join: str = "outer", - combine_attrs: str = "override", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", ) -> DataArray: ... @@ -60,11 +60,11 @@ def concat( dim, data_vars="all", coords="different", - compat="equals", + compat: CompatOptions = "equals", positions=None, fill_value=dtypes.NA, - join="outer", - combine_attrs="override", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "override", ): """Concatenate xarray objects along a new or existing dimension. @@ -231,17 +231,34 @@ def concat( ) if isinstance(first_obj, DataArray): - f = _dataarray_concat + return _dataarray_concat( + arrays=objs, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs + ) elif isinstance(first_obj, Dataset): - f = _dataset_concat + return _dataarray_concat( + arrays=objs, + dim=dim, + data_vars=data_vars, + coords=coords, + compat=compat, + positions=positions, + fill_value=fill_value, + join=join, + combine_attrs=combine_attrs + ) else: raise TypeError( "can only concatenate xarray Dataset and DataArray " f"objects, got {type(first_obj)}" ) - return f( - objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs - ) def _calc_concat_dim_index( @@ -421,7 +438,7 @@ def _dataset_concat( compat: CompatOptions, positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, - join: str = "outer", + join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: """ @@ -610,7 +627,7 @@ def _dataarray_concat( compat: CompatOptions, positions: Iterable[Iterable[int]] | None, fill_value: object = dtypes.NA, - join: str = "outer", + join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", ) -> DataArray: from .dataarray import DataArray diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b6da496b921..e8774563584 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -107,7 +107,7 @@ from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .dataarray import DataArray from .merge import CoercibleMapping - from .types import ErrorOptions, ErrorOptionsWithWarn, T_Xarray + from .types import ErrorOptions, ErrorOptionsWithWarn, JoinOptions, T_Xarray try: from dask.delayed import Delayed @@ -4527,7 +4527,7 @@ def merge( other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), compat: CompatOptions = "no_conflicts", - join: str = "outer", + join: JoinOptions = "outer", fill_value: Any = dtypes.NA, combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 758f5d72a7c..c4d50076604 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -36,7 +36,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset - from .types import CompatOptions, CombineAttrsOptions + from .types import CompatOptions, CombineAttrsOptions, JoinOptions DimsLike = Union[Hashable, Sequence[Hashable]] ArrayLike = Any @@ -545,7 +545,7 @@ def _get_priority_vars_and_indexes( def merge_coords( objects: Iterable[CoercibleMapping], compat: CompatOptions = "minimal", - join: str = "outer", + join: JoinOptions = "outer", priority_arg: int | None = None, indexes: Mapping[Any, Index] | None = 
None, fill_value: object = dtypes.NA, @@ -696,7 +696,7 @@ class _MergeResult(NamedTuple): def merge_core( objects: Iterable[CoercibleMapping], compat: CompatOptions = "broadcast_equals", - join: str = "outer", + join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "override", priority_arg: int | None = None, explicit_coords: Sequence | None = None, @@ -785,7 +785,7 @@ def merge_core( def merge( objects: Iterable[DataArray | CoercibleMapping], compat: CompatOptions = "no_conflicts", - join: str = "outer", + join: JoinOptions = "outer", fill_value: object = dtypes.NA, combine_attrs: CombineAttrsOptions = "override", ) -> Dataset: @@ -1033,7 +1033,7 @@ def dataset_merge_method( other: CoercibleMapping, overwrite_vars: Hashable | Iterable[Hashable], compat: CompatOptions, - join: str, + join: JoinOptions, fill_value: Any, combine_attrs: CombineAttrsOptions, ) -> _MergeResult: diff --git a/xarray/core/types.py b/xarray/core/types.py index 6e4c3523925..e5b35d825ae 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -43,4 +43,5 @@ CombineAttrsOptions = Union[ Literal["drop", "identical", "no_conflicts", "drop_conflicts", "override"], Callable[..., Any], -] \ No newline at end of file +] +JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"] diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index b87837a442a..378dd5b81c9 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import List +from typing import Any, List, TYPE_CHECKING import numpy as np import pandas as pd @@ -18,6 +18,8 @@ ) from .test_dataset import create_test_data +if TYPE_CHECKING: + from xarray.core.types import JoinOptions, CombineAttrsOptions def test_concat_compat() -> None: ds1 = Dataset( @@ -239,7 +241,7 @@ def test_concat_join_kwarg(self) -> None: ds1 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [0], "y": [0]}) ds2 = Dataset({"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]}) - expected = {} + expected: dict[JoinOptions, Any] = {} expected["outer"] = Dataset( {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, {"x": [0, 1], "y": [0, 0.0001]}, @@ -654,7 +656,7 @@ def test_concat_join_kwarg(self) -> None: {"a": (("x", "y"), [[0]])}, coords={"x": [1], "y": [0.0001]} ).to_array() - expected = {} + expected: dict[JoinOptions, Any] = {} expected["outer"] = Dataset( {"a": (("x", "y"), [[0, np.nan], [np.nan, 0]])}, {"x": [0, 1], "y": [0, 0.0001]}, @@ -686,7 +688,7 @@ def test_concat_combine_attrs_kwarg(self) -> None: da1 = DataArray([0], coords=[("x", [0])], attrs={"b": 42}) da2 = DataArray([0], coords=[("x", [1])], attrs={"b": 42, "c": 43}) - expected = {} + expected: dict[CombineAttrsOptions, Any] = {} expected["drop"] = DataArray([0, 0], coords=[("x", [0, 1])]) expected["no_conflicts"] = DataArray( [0, 0], coords=[("x", [0, 1])], attrs={"b": 42, "c": 43} From 2fa193469645f510d3e5a98500faed397164d750 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 May 2022 18:13:26 +0000 Subject: [PATCH 16/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 3 +-- xarray/core/combine.py | 2 +- xarray/core/concat.py | 6 +++--- xarray/core/dataset.py | 5 ++++- xarray/core/merge.py | 8 ++++---- xarray/core/types.py | 2 +- xarray/tests/test_concat.py | 5 +++-- 7 files changed, 17 insertions(+), 14 deletions(-) diff --git 
a/xarray/backends/api.py b/xarray/backends/api.py index 9547eaa5fde..c480c7bd969 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -23,7 +23,6 @@ import numpy as np - from .. import backends, conventions from ..core import indexing from ..core.combine import ( @@ -44,8 +43,8 @@ from dask.delayed import Delayed except ImportError: Delayed = None # type: ignore + from ..core.types import CombineAttrsOptions, CompatOptions, JoinOptions from .common import BackendEntrypoint - from ..core.types import CompatOptions, CombineAttrsOptions, JoinOptions T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ diff --git a/xarray/core/combine.py b/xarray/core/combine.py index ccbea0520d6..33798fa3647 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -3,7 +3,7 @@ import itertools import warnings from collections import Counter -from typing import Any, Callable, Iterable, Literal, Sequence, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Iterable, Literal, Sequence, Union import pandas as pd diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 8e420a25c8c..edc4f87640c 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, overload +from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload import pandas as pd @@ -240,7 +240,7 @@ def concat( positions=positions, fill_value=fill_value, join=join, - combine_attrs=combine_attrs + combine_attrs=combine_attrs, ) elif isinstance(first_obj, Dataset): return _dataarray_concat( @@ -252,7 +252,7 @@ def concat( positions=positions, fill_value=fill_value, join=join, - combine_attrs=combine_attrs + combine_attrs=combine_attrs, ) else: raise TypeError( diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e8774563584..8c5d538d06a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4863,7 +4863,10 @@ def drop_isel(self, indexers=None, **indexers_kwargs): return ds def drop_dims( - self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise" + self, + drop_dims: Hashable | Iterable[Hashable], + *, + errors: ErrorOptions = "raise", ) -> Dataset: """Drop dimensions and associated variables from this dataset. diff --git a/xarray/core/merge.py b/xarray/core/merge.py index c4d50076604..77140cc0aa2 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -5,10 +5,8 @@ TYPE_CHECKING, AbstractSet, Any, - Callable, Hashable, Iterable, - Literal, Mapping, NamedTuple, Optional, @@ -36,7 +34,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset - from .types import CompatOptions, CombineAttrsOptions, JoinOptions + from .types import CombineAttrsOptions, CompatOptions, JoinOptions DimsLike = Union[Hashable, Sequence[Hashable]] ArrayLike = Any @@ -499,7 +497,9 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals" + objects: list[DatasetLike], + priority_arg: int | None, + compat: CompatOptions = "equals", ) -> dict[Hashable, MergeElement]: """Extract the priority variable from a list of mappings. 
diff --git a/xarray/core/types.py b/xarray/core/types.py index e5b35d825ae..dc325a986be 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, TYPE_CHECKING, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union import numpy as np diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 378dd5b81c9..28973e20cd0 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List import numpy as np import pandas as pd @@ -19,7 +19,8 @@ from .test_dataset import create_test_data if TYPE_CHECKING: - from xarray.core.types import JoinOptions, CombineAttrsOptions + from xarray.core.types import CombineAttrsOptions, JoinOptions + def test_concat_compat() -> None: ds1 = Dataset( From dd1dd00ee50032bd595bb00409bfb07254f7e194 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 20:20:06 +0200 Subject: [PATCH 17/45] add some blank lines under bullet lists in docs --- xarray/core/alignment.py | 1 + xarray/core/combine.py | 10 ++++++---- xarray/core/dataset.py | 2 ++ xarray/core/merge.py | 2 ++ 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 4c507cd18e1..c1a9192233e 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -591,6 +591,7 @@ def align( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. + copy : bool, optional If ``copy=True``, data in the return values is always copied. If ``copy=False`` and reindexing is unnecessary, or can be performed with diff --git a/xarray/core/combine.py b/xarray/core/combine.py index ccbea0520d6..f5e8680c0ba 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -714,18 +714,19 @@ def combine_by_coords( must be equal. The returned dataset then contains the combination of all non-null values. - "override": skip comparing and pick variable from first dataset + data_vars : {"minimal", "different", "all" or list of str}, optional These data variables will be concatenated together: - * "minimal": Only data variables in which the dimension already + - "minimal": Only data variables in which the dimension already appears are included. - * "different": Data variables which are not equal (ignoring + - "different": Data variables which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of data variables into memory if they are not already loaded. - * "all": All data variables will be concatenated. - * list of str: The listed data variables will be concatenated, in + - "all": All data variables will be concatenated. + - list of str: The listed data variables will be concatenated, in addition to the "minimal" data variables. If objects are DataArrays, `data_vars` must be "all". @@ -748,6 +749,7 @@ def combine_by_coords( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. 
+ combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ "override"} or callable, default: "drop" A callable or a string indicating how to combine attrs of the objects being diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e8774563584..1ae5968353f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4558,6 +4558,7 @@ def merge( - 'no_conflicts': only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. + join : {"outer", "inner", "left", "right", "exact"}, optional Method for joining ``self`` and ``other`` along shared dimensions: @@ -4566,6 +4567,7 @@ def merge( - 'left': use indexes from ``self`` - 'right': use indexes from ``other`` - 'exact': error instead of aligning non-equal indexes + fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps variable names (including coordinates) to fill values. diff --git a/xarray/core/merge.py b/xarray/core/merge.py index c4d50076604..41ab5b7efbd 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -809,6 +809,7 @@ def merge( must be equal. The returned dataset then contains the combination of all non-null values. - "override": skip comparing and pick variable from first dataset + join : {"outer", "inner", "left", "right", "exact"}, optional String indicating how to combine differing indexes in objects. @@ -821,6 +822,7 @@ def merge( - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. + fill_value : scalar or dict-like, optional Value to use for newly missing values. If a dict-like, maps variable names to fill values. Use a data array's name to From 0f1a1df85e473a83e093976296fa510443eb1a35 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 20:20:46 +0200 Subject: [PATCH 18/45] add comments to overloads --- xarray/backends/api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9547eaa5fde..134354a176e 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1028,6 +1028,7 @@ def multi_file_closer(): } +# multifile=True returns writer and datastore @overload def to_netcdf( dataset: Dataset, @@ -1045,6 +1046,7 @@ def to_netcdf( ... +# path=None writes to bytes @overload def to_netcdf( dataset: Dataset, @@ -1062,6 +1064,7 @@ def to_netcdf( ... +# compute=False returns dask.Delayed @overload def to_netcdf( dataset: Dataset, @@ -1079,6 +1082,7 @@ def to_netcdf( ... 
+# default return None @overload def to_netcdf( dataset: Dataset, From 0d08118e5b09082a0e8fa4dcc3dfb663654b831b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 20:30:53 +0200 Subject: [PATCH 19/45] some more typing --- xarray/core/computation.py | 22 ++++++++++------------ xarray/core/mypy.py | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 12 deletions(-) create mode 100644 xarray/core/mypy.py diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2fe258c2608..10c0ea94059 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -38,7 +38,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset - from .types import T_Xarray + from .types import T_Xarray, JoinOptions, CombineAttrsOptions _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") @@ -185,7 +185,7 @@ def _enumerate(dim): return str(alt_signature) -def result_name(objects: list) -> Any: +def result_name(objects: Sequence[Any]) -> Any: # use the same naming heuristics as pandas: # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects} @@ -197,7 +197,7 @@ def result_name(objects: list) -> Any: return name -def _get_coords_list(args) -> list[Coordinates]: +def _get_coords_list(args: Sequence[Any]) -> list[Coordinates]: coords_list = [] for arg in args: try: @@ -210,19 +210,16 @@ def _get_coords_list(args) -> list[Coordinates]: def build_output_coords_and_indexes( - args: list, + args: Sequence[Any], signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), - combine_attrs: Literal[ - "drop", "identical", "no_conflicts", "drop_conflicts", "override" - ] - | Callable[..., Any] = "override", + combine_attrs: CombineAttrsOptions = "override", ) -> tuple[list[dict[Any, Variable]], list[dict[Any, Index]]]: """Build output coordinates and indexes for an operation. Parameters ---------- - args : list + args : Sequence List of raw operation arguments. Any valid types for xarray operations are OK, e.g., scalars, Variable, DataArray, Dataset. signature : _UfuncSignature @@ -287,10 +284,10 @@ def apply_dataarray_vfunc( func, *args, signature, - join="inner", + join: JoinOptions = "inner", exclude_dims=frozenset(), keep_attrs="override", -): +) -> tuple[DataArray, ...] | DataArray: """Apply a variable level function over DataArray, Variable and/or ndarray objects. """ @@ -315,6 +312,7 @@ def apply_dataarray_vfunc( data_vars = [getattr(a, "variable", a) for a in args] result_var = func(*data_vars) + out: tuple[DataArray, ...] 
| DataArray if signature.num_outputs > 1: out = tuple( DataArray( @@ -849,7 +847,7 @@ def apply_ufunc( output_core_dims: Sequence[Sequence] | None = ((),), exclude_dims: AbstractSet = frozenset(), vectorize: bool = False, - join: str = "exact", + join: JoinOptions = "exact", dataset_join: str = "exact", dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool | str | None = None, diff --git a/xarray/core/mypy.py b/xarray/core/mypy.py new file mode 100644 index 00000000000..070c9a414af --- /dev/null +++ b/xarray/core/mypy.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +# from typing import TypeVar, Union + +# T = TypeVar("T", bound=Union[str, int]) + +# def a(x: T) -> T: +# if isinstance(x, str): +# return x +# else: +# y: int +# y = x +# return x + +# def b(x: T) -> T: +# if isinstance(x, int): +# y: int +# y = x +# return x +# else: +# return x \ No newline at end of file From c6aa64626f713c541809b6d917b243ad0f14c496 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 16 May 2022 20:31:30 +0200 Subject: [PATCH 20/45] fix absolute import --- xarray/core/concat.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index edc4f87640c..5131d99f7be 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -4,8 +4,6 @@ import pandas as pd -from xarray.core.types import CombineAttrsOptions - from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv @@ -22,7 +20,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import CompatOptions, ConcatOptions, JoinOptions + from .types import CompatOptions, ConcatOptions, JoinOptions, CombineAttrsOptions @overload From 9376d36bfcc9cc5e61185352dfdfb1829d06c691 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 May 2022 18:32:37 +0000 Subject: [PATCH 21/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/computation.py | 3 +-- xarray/core/mypy.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 10c0ea94059..cba4615f295 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -15,7 +15,6 @@ Callable, Hashable, Iterable, - Literal, Mapping, Sequence, overload, @@ -38,7 +37,7 @@ from .coordinates import Coordinates from .dataarray import DataArray from .dataset import Dataset - from .types import T_Xarray, JoinOptions, CombineAttrsOptions + from .types import CombineAttrsOptions, JoinOptions, T_Xarray _NO_FILL_VALUE = utils.ReprObject("") _DEFAULT_NAME = utils.ReprObject("") diff --git a/xarray/core/mypy.py b/xarray/core/mypy.py index 070c9a414af..abd012c9f11 100644 --- a/xarray/core/mypy.py +++ b/xarray/core/mypy.py @@ -18,4 +18,4 @@ # y = x # return x # else: -# return x \ No newline at end of file +# return x From ef19bab61c5f5cc1f1aac6ae40f484094dc0d357 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 09:38:11 +0200 Subject: [PATCH 22/45] Delete mypy.py whops, accidential upload --- xarray/core/mypy.py | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 xarray/core/mypy.py diff --git a/xarray/core/mypy.py b/xarray/core/mypy.py deleted file mode 100644 index abd012c9f11..00000000000 --- a/xarray/core/mypy.py +++ /dev/null @@ -1,21 +0,0 @@ -from __future__ import 
annotations - -# from typing import TypeVar, Union - -# T = TypeVar("T", bound=Union[str, int]) - -# def a(x: T) -> T: -# if isinstance(x, str): -# return x -# else: -# y: int -# y = x -# return x - -# def b(x: T) -> T: -# if isinstance(x, int): -# y: int -# y = x -# return x -# else: -# return x From 494b9fadcdda7f25208115f863090ef6c36d40bb Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 09:51:24 +0200 Subject: [PATCH 23/45] fix typo --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index edc4f87640c..408587649a8 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -243,7 +243,7 @@ def concat( combine_attrs=combine_attrs, ) elif isinstance(first_obj, Dataset): - return _dataarray_concat( + return _dataset_concat( arrays=objs, dim=dim, data_vars=data_vars, From a4351164c42d2824282d40c8d77845dd65621c06 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 09:52:44 +0200 Subject: [PATCH 24/45] fix absolute import --- xarray/core/concat.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 408587649a8..5a80dc957aa 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -4,8 +4,6 @@ import pandas as pd -from xarray.core.types import CombineAttrsOptions - from . import dtypes, utils from .alignment import align from .duck_array_ops import lazy_array_equiv @@ -22,7 +20,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import CompatOptions, ConcatOptions, JoinOptions + from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions @overload From 9abb840e45d9eaf7dcd1cc419a0d609a96598bad Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 09:58:19 +0200 Subject: [PATCH 25/45] fix some absolute imports --- xarray/core/dataset.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 73ea969bf76..1cab433a8f2 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -31,7 +31,6 @@ import pandas as pd import xarray as xr -from xarray.core.types import CombineAttrsOptions, CompatOptions from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings @@ -55,6 +54,7 @@ from .common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes from .computation import unify_chunks from .coordinates import DatasetCoordinates, assert_coordinate_consistent +from .dataarray import DataArray from .duck_array_ops import datetime_to_numeric from .indexes import ( Index, @@ -105,9 +105,8 @@ if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore from ..backends.api import T_NetcdfEngine, T_NetcdfTypes - from .dataarray import DataArray from .merge import CoercibleMapping - from .types import ErrorOptions, ErrorOptionsWithWarn, JoinOptions, T_Xarray + from .types import ErrorOptions, ErrorOptionsWithWarn, JoinOptions, T_Xarray, CombineAttrsOptions, CompatOptions try: from dask.delayed import Delayed @@ -160,7 +159,7 @@ def _get_virtual_variable( ref_var = variables[ref_name] if _contains_datetime_like_objects(ref_var): - ref_var = xr.DataArray(ref_var) + ref_var = DataArray(ref_var) data = getattr(ref_var.dt, var_name).data 
else: data = getattr(ref_var, var_name).data @@ -1414,7 +1413,7 @@ def __setitem__(self, key: Hashable | list[Hashable] | Mapping, value) -> None: ) if isinstance(value, Dataset): self.update(dict(zip(key, value.data_vars.values()))) - elif isinstance(value, xr.DataArray): + elif isinstance(value, DataArray): raise ValueError("Cannot assign single DataArray to multiple keys") else: self.update(dict(zip(key, value))) @@ -1449,7 +1448,7 @@ def _setitem_check(self, key, value): "Dataset assignment only accepts DataArrays, Datasets, and scalars." ) - new_value = xr.Dataset() + new_value = Dataset() for name, var in self.items(): # test indexing try: From fc11f6f1599d2cb9ffee590fec8014ef7fbda71b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 May 2022 07:59:51 +0000 Subject: [PATCH 26/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1cab433a8f2..c4ccb7a2b89 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -106,7 +106,14 @@ from ..backends import AbstractDataStore, ZarrStore from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .merge import CoercibleMapping - from .types import ErrorOptions, ErrorOptionsWithWarn, JoinOptions, T_Xarray, CombineAttrsOptions, CompatOptions + from .types import ( + CombineAttrsOptions, + CompatOptions, + ErrorOptions, + ErrorOptionsWithWarn, + JoinOptions, + T_Xarray, + ) try: from dask.delayed import Delayed From b85c347db594d31d2c06c06537dff6d5c945da7a Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 10:02:34 +0200 Subject: [PATCH 27/45] replace Dict by dict --- xarray/tests/test_distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 8ac87bbc807..b683ba73f75 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -2,7 +2,7 @@ import pickle import numpy as np -from typing import Any, Dict +from typing import Any import pytest from packaging.version import Version @@ -158,7 +158,7 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> if consolidated: pytest.importorskip("zarr", minversion="2.2.1.dev2") write_kwargs = {"consolidated": True} - read_kwargs: Dict[str, Any] = {"backend_kwargs": {"consolidated": True}} + read_kwargs: dict[str, Any] = {"backend_kwargs": {"consolidated": True}} else: write_kwargs = read_kwargs = {} # type: ignore chunks = {"dim1": 4, "dim2": 3, "dim3": 5} From 820cb11866fa3844854c237707d63692bf96b2bf Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 10:53:58 +0200 Subject: [PATCH 28/45] fix DataArray import --- xarray/core/dataset.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c4ccb7a2b89..c2d451d9979 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -54,7 +54,6 @@ from .common import DataWithCoords, _contains_datetime_like_objects, get_chunksizes from .computation import unify_chunks from .coordinates import DatasetCoordinates, assert_coordinate_consistent -from .dataarray import DataArray from .duck_array_ops import datetime_to_numeric from .indexes import 
( Index, @@ -105,6 +104,7 @@ if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore from ..backends.api import T_NetcdfEngine, T_NetcdfTypes + from .dataarray import DataArray from .merge import CoercibleMapping from .types import ( CombineAttrsOptions, @@ -147,6 +147,8 @@ def _get_virtual_variable( objects (if possible) """ + from .dataarray import DataArray + if dim_sizes is None: dim_sizes = {} @@ -338,7 +340,7 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults -class DataVariables(Mapping[Any, "DataArray"]): +class DataVariables(Mapping[Any, DataArray]): __slots__ = ("_dataset",) def __init__(self, dataset: Dataset): @@ -1283,7 +1285,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: indexes = filter_indexes_from_coords(self._indexes, set(coords)) - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return xr.DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) def __copy__(self) -> Dataset: return self.copy(deep=False) @@ -1379,7 +1381,8 @@ def __setitem__(self, key: Hashable | list[Hashable] | Mapping, value) -> None: If the given value is also a dataset, select corresponding variables in the given value and in the dataset to be changed. - If value is a `DataArray`, call its `select_vars()` method, rename it + If value is a ` + from .dataarray import DataArray`, call its `select_vars()` method, rename it to `key` and merge the contents of the resulting dataset into this dataset. @@ -1387,6 +1390,8 @@ def __setitem__(self, key: Hashable | list[Hashable] | Mapping, value) -> None: ``(dims, data[, attrs])``), add it to this dataset as a new variable. """ + from .dataarray import DataArray + if utils.is_dict_like(key): # check for consistency and convert value to dataset value = self._setitem_check(key, value) From 768ba792833d5eb30407149bbc06adc4b81f9d23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 May 2022 08:55:54 +0000 Subject: [PATCH 29/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c2d451d9979..ccc410dc351 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -148,7 +148,7 @@ def _get_virtual_variable( """ from .dataarray import DataArray - + if dim_sizes is None: dim_sizes = {} @@ -1391,7 +1391,7 @@ def __setitem__(self, key: Hashable | list[Hashable] | Mapping, value) -> None: variable. 
""" from .dataarray import DataArray - + if utils.is_dict_like(key): # check for consistency and convert value to dataset value = self._setitem_check(key, value) From ebf1e7d433a5087453607f49f0995d1651bfa80c Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Tue, 17 May 2022 10:56:02 +0200 Subject: [PATCH 30/45] fix _dataset_concat arg name --- xarray/core/concat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 5a80dc957aa..92e81dca4e3 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -230,7 +230,7 @@ def concat( if isinstance(first_obj, DataArray): return _dataarray_concat( - arrays=objs, + objs, dim=dim, data_vars=data_vars, coords=coords, @@ -242,7 +242,7 @@ def concat( ) elif isinstance(first_obj, Dataset): return _dataset_concat( - arrays=objs, + objs, dim=dim, data_vars=data_vars, coords=coords, From d64c04ad243093e7ad117c440da452e885705a0b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 13:13:47 +0200 Subject: [PATCH 31/45] fix DataArray not imported --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ccc410dc351..a48ba2fc567 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -340,7 +340,7 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults -class DataVariables(Mapping[Any, DataArray]): +class DataVariables(Mapping[Any, "DataArray"]): __slots__ = ("_dataset",) def __init__(self, dataset: Dataset): From c5c3125b05cdded085a84f7de14e17cdcc2c4612 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 13:19:58 +0200 Subject: [PATCH 32/45] remove xr import in Dataset --- xarray/core/dataset.py | 44 ++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a48ba2fc567..a9dfcec27e4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -30,8 +30,6 @@ import numpy as np import pandas as pd -import xarray as xr - from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods @@ -1285,7 +1283,7 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: indexes = filter_indexes_from_coords(self._indexes, set(coords)) - return xr.DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) def __copy__(self) -> Dataset: return self.copy(deep=False) @@ -1445,6 +1443,7 @@ def _setitem_check(self, key, value): to avoid leaving the dataset in a partially updated state when an error occurs. 
""" from .dataarray import DataArray + from .alignment import align if isinstance(value, Dataset): missing_vars = [ @@ -1497,7 +1496,7 @@ def _setitem_check(self, key, value): # check consistency of dimension sizes and dimension coordinates if isinstance(value, DataArray) or isinstance(value, Dataset): - xr.align(self[key], value, join="exact", copy=False) + align(self[key], value, join="exact", copy=False) return new_value @@ -2175,6 +2174,7 @@ def _validate_indexers( associated index is a DatetimeIndex or CFTimeIndex """ from .dataarray import DataArray + from ..coding.cftimeindex import CFTimeIndex indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims) @@ -2197,7 +2197,7 @@ def _validate_indexers( index = self._indexes[k].to_pandas_index() if isinstance(index, pd.DatetimeIndex): v = v.astype("datetime64[ns]") - elif isinstance(index, xr.CFTimeIndex): + elif isinstance(index, CFTimeIndex): v = _parse_array_of_cftime_strings(v, index.date_type) if v.ndim > 1: @@ -4263,6 +4263,8 @@ def to_stacked_array( Dimensions without coordinates: x """ + from .concat import concat + stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims) for variable in self: @@ -4293,7 +4295,7 @@ def ensure_stackable(val): # concatenate the arrays stackable_vars = [ensure_stackable(self[key]) for key in self.data_vars] - data_array = xr.concat(stackable_vars, dim=new_dim) + data_array = concat(stackable_vars, dim=new_dim) if name is not None: data_array.name = name @@ -4613,7 +4615,9 @@ def merge( -------- Dataset.update """ - other = other.to_dataset() if isinstance(other, xr.DataArray) else other + from .dataarray import DataArray + + other = other.to_dataset() if isinstance(other, DataArray) else other merge_result = dataset_merge_method( self, other, @@ -7239,6 +7243,8 @@ def polyfit( numpy.polyval xarray.polyval """ + from .dataarray import DataArray + variables = {} skipna_da = skipna @@ -7272,10 +7278,10 @@ def polyfit( rank = np.linalg.matrix_rank(lhs) if full: - rank = xr.DataArray(rank, name=xname + "matrix_rank") + rank = DataArray(rank, name=xname + "matrix_rank") variables[rank.name] = rank _sing = np.linalg.svd(lhs, compute_uv=False) - sing = xr.DataArray( + sing = DataArray( _sing, dims=(degree_dim,), coords={degree_dim: np.arange(rank - 1, -1, -1)}, @@ -7328,7 +7334,7 @@ def polyfit( # Thus a ReprObject => polyfit was called on a DataArray name = "" - coeffs = xr.DataArray( + coeffs = DataArray( coeffs / scale_da, dims=[degree_dim] + list(stacked_coords.keys()), coords={degree_dim: np.arange(order)[::-1], **stacked_coords}, @@ -7339,7 +7345,7 @@ def polyfit( variables[coeffs.name] = coeffs if full or (cov is True): - residuals = xr.DataArray( + residuals = DataArray( residuals if dims_to_stack else residuals.squeeze(), dims=list(stacked_coords.keys()), coords=stacked_coords, @@ -7360,7 +7366,7 @@ def polyfit( "The number of data points must exceed order to scale the covariance matrix." 
) fac = residuals / (x.shape[0] - order) - covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j")) * fac + covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac variables[name + "polyfit_covariance"] = covariance return Dataset(data_vars=variables, attrs=self.attrs.copy()) @@ -8013,6 +8019,10 @@ def curvefit( """ from scipy.optimize import curve_fit + from .dataarray import DataArray, _THIS_ARRAY + from .alignment import broadcast + from .computation import apply_ufunc + if p0 is None: p0 = {} if bounds is None: @@ -8029,7 +8039,7 @@ def curvefit( if ( isinstance(coords, str) - or isinstance(coords, xr.DataArray) + or isinstance(coords, DataArray) or not isinstance(coords, Iterable) ): coords = [coords] @@ -8048,7 +8058,7 @@ def curvefit( ) # Broadcast all coords with each other - coords_ = xr.broadcast(*coords_) + coords_ = broadcast(*coords_) coords_ = [ coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_ ] @@ -8083,14 +8093,14 @@ def _wrapper(Y, *coords_, **kwargs): popt, pcov = curve_fit(func, x, y, **kwargs) return popt, pcov - result = xr.Dataset() + result = Dataset() for name, da in self.data_vars.items(): - if name is xr.core.dataarray._THIS_ARRAY: + if name is _THIS_ARRAY: name = "" else: name = f"{str(name)}_" - popt, pcov = xr.apply_ufunc( + popt, pcov = apply_ufunc( _wrapper, da, *coords_, From d6b7cc20e7e6cd268c05ab73282a094731d37edf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 May 2022 11:22:11 +0000 Subject: [PATCH 33/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a9dfcec27e4..833ab6bc207 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1442,8 +1442,8 @@ def _setitem_check(self, key, value): When assigning values to a subset of a Dataset, do consistency check beforehand to avoid leaving the dataset in a partially updated state when an error occurs. 
""" - from .dataarray import DataArray from .alignment import align + from .dataarray import DataArray if isinstance(value, Dataset): missing_vars = [ @@ -2173,8 +2173,8 @@ def _validate_indexers( + string indexers are cast to the appropriate date type if the associated index is a DatetimeIndex or CFTimeIndex """ - from .dataarray import DataArray from ..coding.cftimeindex import CFTimeIndex + from .dataarray import DataArray indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims) @@ -8019,9 +8019,9 @@ def curvefit( """ from scipy.optimize import curve_fit - from .dataarray import DataArray, _THIS_ARRAY from .alignment import broadcast from .computation import apply_ufunc + from .dataarray import _THIS_ARRAY, DataArray if p0 is None: p0 = {} From 7367248297a914aab45842beca3c3c85970cd8bb Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 13:23:54 +0200 Subject: [PATCH 34/45] some more typing --- xarray/core/combine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index ba38263d81c..8ff2b45d105 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -342,7 +342,7 @@ def _nested_combine( ids, fill_value=dtypes.NA, join: JoinOptions = "outer", - combine_attrs="drop", + combine_attrs: CombineAttrsOptions = "drop", ): if len(datasets) == 0: @@ -388,7 +388,7 @@ def combine_nested( coords: str = "different", fill_value: object = dtypes.NA, join: JoinOptions = "outer", - combine_attrs: str = "drop", + combine_attrs: CombineAttrsOptions = "drop", ) -> Dataset: """ Explicitly combine an N-dimensional grid of datasets into one by using a @@ -606,9 +606,9 @@ def _combine_single_variable_hypercube( fill_value=dtypes.NA, data_vars="all", coords="different", - compat="no_conflicts", - join="outer", - combine_attrs="no_conflicts", + compat: CompatOptions = "no_conflicts", + join: JoinOptions = "outer", + combine_attrs: CombineAttrsOptions = "no_conflicts", ): """ Attempt to combine a list of Datasets into a hypercube using their From c30ae552275ff870821094fb42054ebda74f2a7f Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 13:33:42 +0200 Subject: [PATCH 35/45] replace some Sequence by Iterable --- xarray/core/combine.py | 8 +++++--- xarray/core/computation.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 8ff2b45d105..fe4178eca61 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -662,14 +662,14 @@ def _combine_single_variable_hypercube( # TODO remove empty list default param after version 0.21, see PR4696 def combine_by_coords( - data_objects: Sequence[Dataset | DataArray] = [], + data_objects: Iterable[Dataset | DataArray] = [], compat: CompatOptions = "no_conflicts", data_vars: Literal["all", "minimal", "different"] | list[str] = "all", coords: str = "different", fill_value: object = dtypes.NA, join: JoinOptions = "outer", combine_attrs: CombineAttrsOptions = "no_conflicts", - datasets: Sequence[Dataset] = None, + datasets: Iterable[Dataset] = None, ) -> Dataset | DataArray: """ @@ -698,7 +698,7 @@ def combine_by_coords( Parameters ---------- - data_objects : sequence of xarray.Dataset or sequence of xarray.DataArray + data_objects : Iterable of Datasets or DataArrays Data objects to combine. 
compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional @@ -767,6 +767,8 @@ def combine_by_coords( If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. + datasets : Iterable of Datasets + Returns ------- combined : xarray.Dataset or xarray.DataArray diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cba4615f295..da2db39525a 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -184,7 +184,7 @@ def _enumerate(dim): return str(alt_signature) -def result_name(objects: Sequence[Any]) -> Any: +def result_name(objects: Iterable[Any]) -> Any: # use the same naming heuristics as pandas: # https://github.com/blaze/blaze/issues/458#issuecomment-51936356 names = {getattr(obj, "name", _DEFAULT_NAME) for obj in objects} @@ -196,7 +196,7 @@ def result_name(objects: Sequence[Any]) -> Any: return name -def _get_coords_list(args: Sequence[Any]) -> list[Coordinates]: +def _get_coords_list(args: Iterable[Any]) -> list[Coordinates]: coords_list = [] for arg in args: try: @@ -209,7 +209,7 @@ def _get_coords_list(args: Sequence[Any]) -> list[Coordinates]: def build_output_coords_and_indexes( - args: Sequence[Any], + args: Iterable[Any], signature: _UFuncSignature, exclude_dims: AbstractSet = frozenset(), combine_attrs: CombineAttrsOptions = "override", @@ -218,7 +218,7 @@ def build_output_coords_and_indexes( Parameters ---------- - args : Sequence + args : Iterable List of raw operation arguments. Any valid types for xarray operations are OK, e.g., scalars, Variable, DataArray, Dataset. signature : _UfuncSignature From 68b3d589a435185101951ccea4b516976a459ea1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 14:00:51 +0200 Subject: [PATCH 36/45] fix wrong default in docstring --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 833ab6bc207..a5f304271e7 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4585,7 +4585,7 @@ def merge( Value to use for newly missing values. If a dict-like, maps variable names (including coordinates) to fill values. combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \ - "override"} or callable, default: "drop" + "override"} or callable, default: "override" A callable or a string indicating how to combine attrs of the objects being merged: From 5166bc49f62f3bfc42caec12c350289959580754 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 14:01:34 +0200 Subject: [PATCH 37/45] fix docstring indentation --- xarray/core/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a5f304271e7..a8b3abc8ef3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4592,11 +4592,11 @@ def merge( - "drop": empty attrs on returned Dataset. - "identical": all attrs must be the same on every object. - "no_conflicts": attrs from all objects are combined, any that have - the same name must also have the same value. + the same name must also have the same value. - "drop_conflicts": attrs from all objects are combined, any that have - the same name but different values are dropped. + the same name but different values are dropped. - "override": skip comparing and copy attrs from the first dataset to - the result. + the result. 
If a callable, it must expect a sequence of ``attrs`` dicts and a context object as its only parameters. From 9543dead51876f38ef42b1c5fa6761f073045a07 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 18:43:12 +0200 Subject: [PATCH 38/45] fix overloads and type some tests --- xarray/backends/api.py | 124 ++++++++++++++------------- xarray/core/dataarray.py | 153 ++++++++++++++++++++++++++++++++-- xarray/core/dataset.py | 120 ++++++++++++++------------ xarray/tests/test_backends.py | 58 ++++++------- 4 files changed, 303 insertions(+), 152 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 110616bd448..17c81730409 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1031,16 +1031,16 @@ def multi_file_closer(): @overload def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike | None, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: bool, - multifile: Literal[True], - invalid_netcdf: bool, + path_or_file: str | os.PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + multifile: Literal[True] = True, + invalid_netcdf: bool = False, ) -> tuple[ArrayWriter, AbstractDataStore]: ... @@ -1049,16 +1049,17 @@ def to_netcdf( @overload def to_netcdf( dataset: Dataset, - path_or_file: None, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: bool, + path_or_file: None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + *, multifile: Literal[False], - invalid_netcdf: bool, + invalid_netcdf: bool = False, ) -> bytes: ... @@ -1067,16 +1068,17 @@ def to_netcdf( @overload def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, + path_or_file: str | os.PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, compute: Literal[False], multifile: Literal[False], - invalid_netcdf: bool, + invalid_netcdf: bool = False, ) -> Delayed: ... 
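
A note on the overload pattern these to_netcdf hunks apply: the defaults now live on the overload stubs themselves, and the overload that selects the dask-delayed return type takes its Literal[False] flag keyword-only (the bare *) with no default, so a type checker only picks it when the caller spells out compute=False. Below is a minimal, self-contained sketch of that pattern; the names (save, the stand-in Delayed class) are made up for illustration and are not xarray's real signatures.

    from __future__ import annotations

    from typing import Literal, overload


    class Delayed:
        """Stand-in for dask.delayed.Delayed so the sketch has no dependencies."""


    @overload
    def save(path: None = None, *, compute: Literal[True] = True) -> bytes: ...
    @overload
    def save(path: str, *, compute: Literal[True] = True) -> None: ...
    @overload
    def save(path: str | None = None, *, compute: Literal[False]) -> Delayed: ...
    def save(path: str | None = None, *, compute: bool = True) -> bytes | None | Delayed:
        # Single runtime implementation; only the overload stubs are visible to mypy.
        if not compute:
            return Delayed()
        return b"fake netCDF payload" if path is None else None


    payload = save()                        # inferred as bytes
    save("out.nc")                          # inferred as None
    later = save("out.nc", compute=False)   # inferred as Delayed

Because the three Literal cases do not overlap, the checker can narrow the return type at each call site without a cast.
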
@@ -1086,15 +1088,16 @@ def to_netcdf( def to_netcdf( dataset: Dataset, path_or_file: str | os.PathLike, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: Literal[True], + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + *, multifile: Literal[False], - invalid_netcdf: bool, + invalid_netcdf: bool = False, ) -> None: ... @@ -1439,47 +1442,50 @@ def check_dtype(vname, var): check_dtype(vname, var) +# compute=True returns ZarrStore @overload def to_zarr( dataset: Dataset, - store: MutableMapping | str | os.PathLike | None, - chunk_store: MutableMapping | str | os.PathLike | None, - mode: Literal["w", "w-", "a", "r+", None], - synchronizer, - group: str | None, - encoding: Mapping | None, - compute: Literal[True], - consolidated: bool | None, - append_dim: Hashable | None, - region: Mapping[str, slice] | None, - safe_chunks: bool, - storage_options: dict[str, str] | None, + store: MutableMapping | str | os.PathLike[str] | None = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: Literal["w", "w-", "a", "r+", None] = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None ) -> backends.ZarrStore: ... +# computs=False returns dask.Delayed @overload def to_zarr( dataset: Dataset, - store: MutableMapping | str | os.PathLike | None, - chunk_store: MutableMapping | str | os.PathLike | None, - mode: Literal["w", "w-", "a", "r+", None], - synchronizer, - group: str | None, - encoding: Mapping | None, + store: MutableMapping | str | os.PathLike[str] | None = None, + chunk_store: MutableMapping | str | os.PathLike | None = None, + mode: Literal["w", "w-", "a", "r+", None] = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, compute: Literal[False], - consolidated: bool | None, - append_dim: Hashable | None, - region: Mapping[str, slice] | None, - safe_chunks: bool, - storage_options: dict[str, str] | None, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None ) -> Delayed: ... 
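
In the to_zarr overloads just above, the store parameter is annotated as os.PathLike[str] rather than bare os.PathLike. The protocol is generic in the type that __fspath__ returns, so the parametrized form admits str-returning path objects such as pathlib.Path while excluding bytes paths. A small illustration, using a hypothetical fspath_str helper that is not part of xarray:

    from __future__ import annotations

    import os
    from pathlib import Path


    def fspath_str(path: str | os.PathLike[str]) -> str:
        # os.PathLike is generic in the return type of __fspath__, so
        # PathLike[str] rules out bytes-based path objects.
        return os.fspath(path)


    print(fspath_str(Path("store.zarr")))  # Path.__fspath__ returns str, so this checks
    # fspath_str(b"store.zarr")            # rejected by a type checker: bytes is not PathLike[str]
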
def to_zarr( dataset: Dataset, - store: MutableMapping | str | os.PathLike | None = None, + store: MutableMapping | str | os.PathLike[str] | None = None, chunk_store: MutableMapping | str | os.PathLike | None = None, mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4fc5f7c0c5c..569e3af09c2 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2,6 +2,7 @@ import datetime import warnings +from os import PathLike from typing import ( TYPE_CHECKING, Any, @@ -12,6 +13,7 @@ Mapping, Sequence, cast, + overload, ) import numpy as np @@ -78,6 +80,7 @@ except ImportError: iris_Cube = None + from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .types import ErrorOptions, ErrorOptionsWithWarn, T_DataArray, T_Xarray @@ -2891,12 +2894,139 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) + # path=None writes to bytes + @overload def to_netcdf( - self, *args, **kwargs - ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: - """Write DataArray contents to a netCDF file. + self, + path: None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes: + ... + + # default return None + @overload + def to_netcdf( + self, + path: str | PathLike, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + invalid_netcdf: bool = False, + ) -> None: + ... + + # compute=False returns dask.Delayed + @overload + def to_netcdf( + self, + path: str | PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: + ... - All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`. + def to_netcdf( + self, + path: str | PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, + ) -> bytes | Delayed | None: + """Write dataset contents to a netCDF file. + + Parameters + ---------- + path : str, path-like or file-like, optional + Path to which to save this dataset. File-like objects are only + supported by the scipy engine. If no path is provided, this + function returns the resulting netCDF file as bytes; in this case, + we need to use scipy, which does not support netCDF version 4 (the + default format becomes NETCDF3_64BIT). + mode : {"w", "a"}, default: "w" + Write ('w') or append ('a') mode. If mode='w', any existing file at + this location will be overwritten. 
If mode='a', existing variables + will be overwritten. + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional + File format for the resulting netCDF file: + + * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API + features. + * NETCDF4_CLASSIC: Data is stored in an HDF5 file, using only + netCDF 3 compatible API features. + * NETCDF3_64BIT: 64-bit offset version of the netCDF 3 file format, + which fully supports 2+ GB files, but is only compatible with + clients linked against netCDF version 3.6.0 or later. + * NETCDF3_CLASSIC: The classic netCDF 3 file format. It does not + handle 2+ GB files very well. + + All formats are supported by the netCDF4-python library. + scipy.io.netcdf only supports the last two formats. + + The default format is NETCDF4 if you are saving a file to disk and + have the netCDF4-python library available. Otherwise, xarray falls + back to using scipy to write netCDF files and defaults to the + NETCDF3_64BIT format (scipy does not support netCDF4). + group : str, optional + Path to the netCDF4 group in the given file to open (only works for + format='NETCDF4'). The group(s) will be created if necessary. + engine : {"netcdf4", "scipy", "h5netcdf"}, optional + Engine to use when writing netCDF files. If not provided, the + default engine is chosen based on available dependencies, with a + preference for 'netcdf4' if writing to a file on disk. + encoding : dict, optional + Nested dictionary with variable names as keys and dictionaries of + variable specific encodings as values, e.g., + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}`` + + The `h5netcdf` engine supports both the NetCDF4-style compression + encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py + ones ``{"compression": "gzip", "compression_opts": 9}``. + This allows using any compression plugin installed in the HDF5 + library, e.g. LZF. + + unlimited_dims : iterable of hashable, optional + Dimension(s) that should be serialized as unlimited dimensions. + By default, no dimensions are treated as unlimited dimensions. + Note that unlimited_dims may also be set via + ``dataset.encoding["unlimited_dims"]``. + compute: bool, default: True + If true compute immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed later. + invalid_netcdf: bool, default: False + Only valid along with ``engine="h5netcdf"``. If True, allow writing + hdf5 files which are invalid netcdf as described in + https://github.com/h5netcdf/h5netcdf. + + Returns + ------- + * ``bytes`` if path is None + * ``dask.delayed.Delayed`` if compute is False + * None otherwise Notes ----- @@ -2910,7 +3040,7 @@ def to_netcdf( -------- Dataset.to_netcdf """ - from ..backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE + from ..backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE, to_netcdf if self.name is None: # If no name is set then use a generic xarray name @@ -2924,7 +3054,18 @@ def to_netcdf( # No problems with the name - so we're fine! 
dataset = self.to_dataset() - return dataset.to_netcdf(*args, **kwargs) + return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( + dataset, + path, + mode=mode, + format=format, + group=group, + engine=engine, + encoding=encoding, + unlimited_dims=unlimited_dims, + compute=compute, + invalid_netcdf=invalid_netcdf, + ) def to_dict(self, data: bool = True) -> dict: """ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a8b3abc8ef3..e5fd5b95606 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1688,49 +1688,53 @@ def dump_to_store(self, store: AbstractDataStore, **kwargs) -> None: # with to_netcdf() dump_to_store(self, store, **kwargs) + # path=None writes to bytes @overload def to_netcdf( self, - path: None, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: bool, - invalid_netcdf: bool, + path: None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: bool = True, + invalid_netcdf: bool = False, ) -> bytes: ... + # default return None @overload def to_netcdf( self, path: str | PathLike, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: Literal[False], - invalid_netcdf: bool, - ) -> Delayed: + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + compute: Literal[True] = True, + invalid_netcdf: bool = False, + ) -> None: ... + # compute=False returns dask.Delayed @overload def to_netcdf( self, - path: str | PathLike, - mode: Literal["w", "a"], - format: T_NetcdfTypes | None, - group: str | None, - engine: T_NetcdfEngine | None, - encoding: Mapping[Hashable, Mapping[str, Any]] | None, - unlimited_dims: Iterable[Hashable] | None, - compute: Literal[True], - invalid_netcdf: bool, - ) -> None: + path: str | PathLike | None = None, + mode: Literal["w", "a"] = "w", + format: T_NetcdfTypes | None = None, + group: str | None = None, + engine: T_NetcdfEngine | None = None, + encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, + unlimited_dims: Iterable[Hashable] | None = None, + *, + compute: Literal[False], + invalid_netcdf: bool = False, + ) -> Delayed: ... 
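
With the Dataset.to_netcdf overloads above in place, and DataArray.to_netcdf now forwarding to the module-level to_netcdf, a type checker can tell the three call patterns apart. A short usage sketch; the file names are arbitrary, returning bytes requires the scipy backend at runtime, and compute=False requires dask:

    import xarray as xr

    ds = xr.Dataset({"a": ("x", [1, 2, 3])})

    blob = ds.to_netcdf()                            # no path given: inferred as bytes
    ds.to_netcdf("out.nc")                           # path given: inferred as None
    later = ds.to_netcdf("lazy.nc", compute=False)   # inferred as dask.delayed.Delayed
    later.compute()                                  # the actual write happens here
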
def to_netcdf( @@ -1835,45 +1839,50 @@ def to_netcdf( invalid_netcdf=invalid_netcdf, ) + + # compute=True (default) returns ZarrStore @overload def to_zarr( self, - store: MutableMapping | str | PathLike | None, - chunk_store: MutableMapping | str | PathLike | None, - mode: Literal["w", "w-", "a", "r+", None], - synchronizer, - group: str | None, - encoding: Mapping | None, - compute: Literal[False], - consolidated: bool | None, - append_dim: Hashable | None, - region: Mapping[str, slice] | None, - safe_chunks: bool, - storage_options: dict[str, str] | None, - ) -> Delayed: + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: Literal["w", "w-", "a", "r+", None] = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + compute: Literal[True] = True, + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + ) -> ZarrStore: ... + # compute=False returns dask.Delayed @overload def to_zarr( self, - store: MutableMapping | str | PathLike | None, - chunk_store: MutableMapping | str | PathLike | None, - mode: Literal["w", "w-", "a", "r+", None], - synchronizer, - group: str | None, - encoding: Mapping | None, - compute: Literal[True], - consolidated: bool | None, - append_dim: Hashable | None, - region: Mapping[str, slice] | None, - safe_chunks: bool, - storage_options: dict[str, str] | None, - ) -> ZarrStore: + store: MutableMapping | str | PathLike[str] | None = None, + chunk_store: MutableMapping | str | PathLike | None = None, + mode: Literal["w", "w-", "a", "r+", None] = None, + synchronizer=None, + group: str | None = None, + encoding: Mapping | None = None, + *, + compute: Literal[False], + consolidated: bool | None = None, + append_dim: Hashable | None = None, + region: Mapping[str, slice] | None = None, + safe_chunks: bool = True, + storage_options: dict[str, str] | None = None, + ) -> Delayed: ... + def to_zarr( self, - store: MutableMapping | str | PathLike | None = None, + store: MutableMapping | str | PathLike[str] | None = None, chunk_store: MutableMapping | str | PathLike | None = None, mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, @@ -5138,6 +5147,7 @@ def interpolate_na( provided. - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. + use_coordinate : bool, str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. 
If False, values are treated as if diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 6b6f6e462bd..0ea38a88fe0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4946,7 +4946,7 @@ def new_dataset_and_coord_attrs(): @requires_scipy_or_netCDF4 class TestDataArrayToNetCDF: - def test_dataarray_to_netcdf_no_name(self): + def test_dataarray_to_netcdf_no_name(self) -> None: original_da = DataArray(np.arange(12).reshape((3, 4))) with create_tmp_file() as tmp: @@ -4955,7 +4955,7 @@ def test_dataarray_to_netcdf_no_name(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) - def test_dataarray_to_netcdf_with_name(self): + def test_dataarray_to_netcdf_with_name(self) -> None: original_da = DataArray(np.arange(12).reshape((3, 4)), name="test") with create_tmp_file() as tmp: @@ -4964,7 +4964,7 @@ def test_dataarray_to_netcdf_with_name(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) - def test_dataarray_to_netcdf_coord_name_clash(self): + def test_dataarray_to_netcdf_coord_name_clash(self) -> None: original_da = DataArray( np.arange(12).reshape((3, 4)), dims=["x", "y"], name="x" ) @@ -4975,7 +4975,7 @@ def test_dataarray_to_netcdf_coord_name_clash(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) - def test_open_dataarray_options(self): + def test_open_dataarray_options(self) -> None: data = DataArray(np.arange(5), coords={"y": ("x", range(5))}, dims=["x"]) with create_tmp_file() as tmp: @@ -4986,13 +4986,13 @@ def test_open_dataarray_options(self): assert_identical(expected, loaded) @requires_scipy - def test_dataarray_to_netcdf_return_bytes(self): + def test_dataarray_to_netcdf_return_bytes(self) -> None: # regression test for GH1410 data = xr.DataArray([1, 2, 3]) output = data.to_netcdf() assert isinstance(output, bytes) - def test_dataarray_to_netcdf_no_name_pathlib(self): + def test_dataarray_to_netcdf_no_name_pathlib(self) -> None: original_da = DataArray(np.arange(12).reshape((3, 4))) with create_tmp_file() as tmp: @@ -5004,7 +5004,7 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): @requires_scipy_or_netCDF4 -def test_no_warning_from_dask_effective_get(): +def test_no_warning_from_dask_effective_get() -> None: with create_tmp_file() as tmpfile: with assert_no_warnings(): ds = Dataset() @@ -5012,7 +5012,7 @@ def test_no_warning_from_dask_effective_get(): @requires_scipy_or_netCDF4 -def test_source_encoding_always_present(): +def test_source_encoding_always_present() -> None: # Test for GH issue #2550. 
rnddata = np.random.randn(10) original = Dataset({"foo": ("x", rnddata)}) @@ -5030,13 +5030,12 @@ def _assert_no_dates_out_of_range_warning(record): @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_use_cftime_standard_calendar_default_in_range(calendar): +def test_use_cftime_standard_calendar_default_in_range(calendar) -> None: x = [0, 1] time = [0, 720] units_date = "2000-01-01" units = "days since 2000-01-01" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5061,14 +5060,13 @@ def test_use_cftime_standard_calendar_default_in_range(calendar): @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2500]) -def test_use_cftime_standard_calendar_default_out_of_range(calendar, units_year): +def test_use_cftime_standard_calendar_default_out_of_range(calendar, units_year) -> None: import cftime x = [0, 1] time = [0, 720] units = f"days since {units_year}-01-01" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5092,14 +5090,13 @@ def test_use_cftime_standard_calendar_default_out_of_range(calendar, units_year) @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _ALL_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2000, 2500]) -def test_use_cftime_true(calendar, units_year): +def test_use_cftime_true(calendar, units_year) -> None: import cftime x = [0, 1] time = [0, 720] units = f"days since {units_year}-01-01" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5122,13 +5119,12 @@ def test_use_cftime_true(calendar, units_year): @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) -def test_use_cftime_false_standard_calendar_in_range(calendar): +def test_use_cftime_false_standard_calendar_in_range(calendar) -> None: x = [0, 1] time = [0, 720] units_date = "2000-01-01" units = "days since 2000-01-01" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5152,12 +5148,11 @@ def test_use_cftime_false_standard_calendar_in_range(calendar): @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2500]) -def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year): +def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year) -> None: x = [0, 1] time = [0, 720] units = f"days since {units_year}-01-01" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5171,12 +5166,11 @@ def 
test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year): @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2000, 2500]) -def test_use_cftime_false_nonstandard_calendar(calendar, units_year): +def test_use_cftime_false_nonstandard_calendar(calendar, units_year) -> None: x = [0, 1] time = [0, 720] units = f"days since {units_year}" - original = DataArray(x, [("time", time)], name="x") - original = original.to_dataset() + original = DataArray(x, [("time", time)], name="x").to_dataset() for v in ["x", "time"]: original[v].attrs["units"] = units original[v].attrs["calendar"] = calendar @@ -5244,7 +5238,7 @@ def test_extract_zarr_variable_encoding(): @requires_zarr @requires_fsspec @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager") -def test_open_fsspec(): +def test_open_fsspec() -> None: import fsspec import zarr @@ -5286,7 +5280,7 @@ def test_open_fsspec(): @requires_h5netcdf @requires_netCDF4 -def test_load_single_value_h5netcdf(tmp_path): +def test_load_single_value_h5netcdf(tmp_path: Path) -> None: """Test that numeric single-element vector attributes are handled fine. At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable @@ -5311,7 +5305,7 @@ def test_load_single_value_h5netcdf(tmp_path): @pytest.mark.parametrize( "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}] ) -def test_open_dataset_chunking_zarr(chunks, tmp_path): +def test_open_dataset_chunking_zarr(chunks, tmp_path: Path) -> None: encoded_chunks = 100 dask_arr = da.from_array( np.ones((500, 500), dtype="float64"), chunks=encoded_chunks @@ -5376,7 +5370,7 @@ def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): @requires_netCDF4 -def test_netcdf4_entrypoint(tmp_path): +def test_netcdf4_entrypoint(tmp_path: Path) -> None: entrypoint = NetCDF4BackendEntrypoint() ds = create_test_data() @@ -5403,7 +5397,7 @@ def test_netcdf4_entrypoint(tmp_path): @requires_scipy -def test_scipy_entrypoint(tmp_path): +def test_scipy_entrypoint(tmp_path: Path) -> None: entrypoint = ScipyBackendEntrypoint() ds = create_test_data() @@ -5433,7 +5427,7 @@ def test_scipy_entrypoint(tmp_path): @requires_h5netcdf -def test_h5netcdf_entrypoint(tmp_path): +def test_h5netcdf_entrypoint(tmp_path: Path) -> None: entrypoint = H5netcdfBackendEntrypoint() ds = create_test_data() From b2d02a7d4e60e19ac7b7e20d10c5f6217ad12134 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 May 2022 16:45:17 +0000 Subject: [PATCH 39/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 4 ++-- xarray/core/dataset.py | 2 -- xarray/tests/test_backends.py | 4 +++- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 17c81730409..79d39f43598 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1457,7 +1457,7 @@ def to_zarr( append_dim: Hashable | None = None, region: Mapping[str, slice] | None = None, safe_chunks: bool = True, - storage_options: dict[str, str] | None = None + storage_options: dict[str, str] | None = None, ) -> backends.ZarrStore: ... 
@@ -1478,7 +1478,7 @@ def to_zarr( append_dim: Hashable | None = None, region: Mapping[str, slice] | None = None, safe_chunks: bool = True, - storage_options: dict[str, str] | None = None + storage_options: dict[str, str] | None = None, ) -> Delayed: ... diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e5fd5b95606..74533a824c9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1839,7 +1839,6 @@ def to_netcdf( invalid_netcdf=invalid_netcdf, ) - # compute=True (default) returns ZarrStore @overload def to_zarr( @@ -1879,7 +1878,6 @@ def to_zarr( ) -> Delayed: ... - def to_zarr( self, store: MutableMapping | str | PathLike[str] | None = None, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0ea38a88fe0..aeecf968582 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5060,7 +5060,9 @@ def test_use_cftime_standard_calendar_default_in_range(calendar) -> None: @requires_scipy_or_netCDF4 @pytest.mark.parametrize("calendar", _STANDARD_CALENDARS) @pytest.mark.parametrize("units_year", [1500, 2500]) -def test_use_cftime_standard_calendar_default_out_of_range(calendar, units_year) -> None: +def test_use_cftime_standard_calendar_default_out_of_range( + calendar, units_year +) -> None: import cftime x = [0, 1] From 502f7c15ed8aacb34be30f5305ec25ffad8998fe Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 19:05:15 +0200 Subject: [PATCH 40/45] fix open_mfdataset typing --- xarray/backends/api.py | 7 ++-- xarray/core/types.py | 12 +++++- xarray/tests/test_backends.py | 74 +++++++++++++++++------------------ 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 17c81730409..ca9ea1c1aab 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -45,6 +45,7 @@ Delayed = None # type: ignore from ..core.types import CombineAttrsOptions, CompatOptions, JoinOptions from .common import BackendEntrypoint + from ..core.types import NestedSequence T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ @@ -722,7 +723,7 @@ def open_dataarray( def open_mfdataset( - paths: str | Iterable[str | os.PathLike], + paths: str | NestedSequence[str | os.PathLike], chunks: T_Chunks = None, concat_dim: str | DataArray @@ -908,8 +909,8 @@ def open_mfdataset( ), expand=False, ) - paths = fs.glob(fs._strip_protocol(paths)) # finds directories - paths = [fs.get_mapper(path) for path in paths] + tmp_paths = fs.glob(fs._strip_protocol(paths)) # finds directories + paths = [fs.get_mapper(path) for path in tmp_paths] elif is_remote_uri(paths): raise ValueError( "cannot do wild-card matching for paths that are remote URLs " diff --git a/xarray/core/types.py b/xarray/core/types.py index dc325a986be..7c239481f49 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, Sequence import numpy as np @@ -45,3 +45,13 @@ Callable[..., Any], ] JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"] + +# TODO: Wait until mypy supports recursive objects in combination with typevars +_T = TypeVar("_T") +NestedSequence = Union[ + _T, + Sequence[_T], + Sequence[Sequence[_T]], + Sequence[Sequence[Sequence[_T]]], + Sequence[Sequence[Sequence[Sequence[_T]]]], +] \ No newline at end of file diff --git 
a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0ea38a88fe0..db61764f521 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3434,7 +3434,7 @@ def test_dataset_caching(self): actual.foo.values # no caching assert not actual.foo.variable._in_memory - def test_open_mfdataset(self): + def test_open_mfdataset(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3457,14 +3457,14 @@ def test_open_mfdataset(self): open_mfdataset("http://some/remote/uri") @requires_fsspec - def test_open_mfdataset_no_files(self): + def test_open_mfdataset_no_files(self) -> None: pytest.importorskip("aiobotocore") # glob is attempted as of #4823, but finds no files with pytest.raises(OSError, match=r"no files"): open_mfdataset("http://some/remote/uri", engine="zarr") - def test_open_mfdataset_2d(self): + def test_open_mfdataset_2d(self) -> None: original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3493,7 +3493,7 @@ def test_open_mfdataset_2d(self): (2, 2, 2, 2), ) - def test_open_mfdataset_pathlib(self): + def test_open_mfdataset_pathlib(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3506,7 +3506,7 @@ def test_open_mfdataset_pathlib(self): ) as actual: assert_identical(original, actual) - def test_open_mfdataset_2d_pathlib(self): + def test_open_mfdataset_2d_pathlib(self) -> None: original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3527,7 +3527,7 @@ def test_open_mfdataset_2d_pathlib(self): ) as actual: assert_identical(original, actual) - def test_open_mfdataset_2(self): + def test_open_mfdataset_2(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3539,7 +3539,7 @@ def test_open_mfdataset_2(self): ) as actual: assert_identical(original, actual) - def test_attrs_mfdataset(self): + def test_attrs_mfdataset(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3559,7 +3559,7 @@ def test_attrs_mfdataset(self): with pytest.raises(AttributeError, match=r"no attribute"): actual.test2 - def test_open_mfdataset_attrs_file(self): + def test_open_mfdataset_attrs_file(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_files(2) as (tmp1, tmp2): ds1 = original.isel(x=slice(5)) @@ -3576,7 +3576,7 @@ def test_open_mfdataset_attrs_file(self): # attributes from ds1 are not retained, e.g., assert "test1" not in actual.attrs - def test_open_mfdataset_attrs_file_path(self): + def test_open_mfdataset_attrs_file_path(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_files(2) as (tmp1, tmp2): tmp1 = Path(tmp1) @@ -3595,7 +3595,7 @@ def test_open_mfdataset_attrs_file_path(self): # attributes from ds1 are not retained, e.g., assert "test1" not in actual.attrs - def test_open_mfdataset_auto_combine(self): + def test_open_mfdataset_auto_combine(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: @@ -3605,7 +3605,7 @@ def test_open_mfdataset_auto_combine(self): with open_mfdataset([tmp2, tmp1], 
combine="by_coords") as actual: assert_identical(original, actual) - def test_open_mfdataset_raise_on_bad_combine_args(self): + def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: # Regression test for unhelpful error shown in #5230 original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)}) with create_tmp_file() as tmp1: @@ -3616,7 +3616,7 @@ def test_open_mfdataset_raise_on_bad_combine_args(self): open_mfdataset([tmp1, tmp2], concat_dim="x") @pytest.mark.xfail(reason="mfdataset loses encoding currently.") - def test_encoding_mfdataset(self): + def test_encoding_mfdataset(self) -> None: original = Dataset( { "foo": ("t", np.random.randn(10)), @@ -3638,7 +3638,7 @@ def test_encoding_mfdataset(self): assert actual.t.encoding["units"] == ds1.t.encoding["units"] assert actual.t.encoding["units"] != ds2.t.encoding["units"] - def test_preprocess_mfdataset(self): + def test_preprocess_mfdataset(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) @@ -3652,7 +3652,7 @@ def preprocess(ds): ) as actual: assert_identical(expected, actual) - def test_save_mfdataset_roundtrip(self): + def test_save_mfdataset_roundtrip(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] with create_tmp_file() as tmp1: @@ -3663,20 +3663,20 @@ def test_save_mfdataset_roundtrip(self): ) as actual: assert_identical(actual, original) - def test_save_mfdataset_invalid(self): + def test_save_mfdataset_invalid(self) -> None: ds = Dataset() with pytest.raises(ValueError, match=r"cannot use mode"): save_mfdataset([ds, ds], ["same", "same"]) with pytest.raises(ValueError, match=r"same length"): save_mfdataset([ds, ds], ["only one path"]) - def test_save_mfdataset_invalid_dataarray(self): + def test_save_mfdataset_invalid_dataarray(self) -> None: # regression test for GH1555 da = DataArray([1, 2]) with pytest.raises(TypeError, match=r"supports writing Dataset"): save_mfdataset([da], ["dataarray"]) - def test_save_mfdataset_pathlib_roundtrip(self): + def test_save_mfdataset_pathlib_roundtrip(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] with create_tmp_file() as tmp1: @@ -3689,7 +3689,7 @@ def test_save_mfdataset_pathlib_roundtrip(self): ) as actual: assert_identical(actual, original) - def test_open_and_do_math(self): + def test_open_and_do_math(self) -> None: original = Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) @@ -3697,7 +3697,7 @@ def test_open_and_do_math(self): actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) - def test_open_mfdataset_concat_dim_none(self): + def test_open_mfdataset_concat_dim_none(self) -> None: with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: data = Dataset({"x": 0}) @@ -3708,7 +3708,7 @@ def test_open_mfdataset_concat_dim_none(self): ) as actual: assert_identical(data, actual) - def test_open_mfdataset_concat_dim_default_none(self): + def test_open_mfdataset_concat_dim_default_none(self) -> None: with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: data = Dataset({"x": 0}) @@ -3717,7 +3717,7 @@ def test_open_mfdataset_concat_dim_default_none(self): with open_mfdataset([tmp1, tmp2], combine="nested") as actual: assert_identical(data, actual) - def test_open_dataset(self): + def test_open_dataset(self) -> None: original = 
Dataset({"foo": ("x", np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) @@ -3731,7 +3731,7 @@ def test_open_dataset(self): assert isinstance(actual.foo.variable.data, np.ndarray) assert_identical(original, actual) - def test_open_single_dataset(self): + def test_open_single_dataset(self) -> None: # Test for issue GH #1988. This makes sure that the # concat_dim is utilized when specified in open_mfdataset(). rnddata = np.random.randn(10) @@ -3745,7 +3745,7 @@ def test_open_single_dataset(self): with open_mfdataset([tmp], concat_dim=dim, combine="nested") as actual: assert_identical(expected, actual) - def test_open_multi_dataset(self): + def test_open_multi_dataset(self) -> None: # Test for issue GH #1988 and #2647. This makes sure that the # concat_dim is utilized when specified in open_mfdataset(). # The additional wrinkle is to ensure that a length greater @@ -3770,7 +3770,7 @@ def test_open_multi_dataset(self): ) as actual: assert_identical(expected, actual) - def test_dask_roundtrip(self): + def test_dask_roundtrip(self) -> None: with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) @@ -3782,7 +3782,7 @@ def test_dask_roundtrip(self): with open_dataset(tmp2) as on_disk: assert_identical(data, on_disk) - def test_deterministic_names(self): + def test_deterministic_names(self) -> None: with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) @@ -3795,7 +3795,7 @@ def test_deterministic_names(self): assert dask_name[:13] == "open_dataset-" assert original_names == repeat_names - def test_dataarray_compute(self): + def test_dataarray_compute(self) -> None: # Test DataArray.compute() on dask backend. # The test for Dataset.compute() is already in DatasetIOBase; # however dask is the only tested backend which supports DataArrays @@ -3806,7 +3806,7 @@ def test_dataarray_compute(self): assert_allclose(actual, computed, decode_bytes=False) @pytest.mark.xfail - def test_save_mfdataset_compute_false_roundtrip(self): + def test_save_mfdataset_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed original = Dataset({"foo": ("x", np.random.randn(10))}).chunk() @@ -3823,7 +3823,7 @@ def test_save_mfdataset_compute_false_roundtrip(self): ) as actual: assert_identical(actual, original) - def test_load_dataset(self): + def test_load_dataset(self) -> None: with create_tmp_file() as tmp: original = Dataset({"foo": ("x", np.random.randn(10))}) original.to_netcdf(tmp) @@ -3831,7 +3831,7 @@ def test_load_dataset(self): # this would fail if we used open_dataset instead of load_dataset ds.to_netcdf(tmp) - def test_load_dataarray(self): + def test_load_dataarray(self) -> None: with create_tmp_file() as tmp: original = Dataset({"foo": ("x", np.random.randn(10))}) original.to_netcdf(tmp) @@ -3844,7 +3844,7 @@ def test_load_dataarray(self): ON_WINDOWS, reason="counting number of tasks in graph fails on windows for some reason", ) - def test_inline_array(self): + def test_inline_array(self) -> None: with create_tmp_file() as tmp: original = Dataset({"foo": ("x", np.random.randn(10))}) original.to_netcdf(tmp) @@ -3853,13 +3853,13 @@ def test_inline_array(self): def num_graph_nodes(obj): return len(obj.__dask_graph__()) - not_inlined = open_dataset(tmp, inline_array=False, chunks=chunks) - inlined = open_dataset(tmp, inline_array=True, chunks=chunks) - assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + not_inlined_ds = open_dataset(tmp, inline_array=False, chunks=chunks) + inlined_ds = open_dataset(tmp, 
inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined_ds) < num_graph_nodes(not_inlined_ds) - not_inlined = open_dataarray(tmp, inline_array=False, chunks=chunks) - inlined = open_dataarray(tmp, inline_array=True, chunks=chunks) - assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + not_inlined_da = open_dataarray(tmp, inline_array=False, chunks=chunks) + inlined_da = open_dataarray(tmp, inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined_da) < num_graph_nodes(not_inlined_da) @requires_scipy_or_netCDF4 @@ -4340,7 +4340,7 @@ def create_tmp_geotiff( @requires_rasterio class TestRasterio: @requires_scipy_or_netCDF4 - def test_serialization(self): + def test_serialization(self) -> None: with create_tmp_geotiff(additional_attrs={}) as (tmp_file, expected): # Write it to a netcdf and read again (roundtrip) with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda: From f2afb6f0d39350ded7e92c2271f1bf37cf277463 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 19:07:08 +0200 Subject: [PATCH 41/45] minor update of docstring --- xarray/backends/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2adb875211e..8cdd13a579c 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -757,7 +757,7 @@ def open_mfdataset( Parameters ---------- - paths : str or Iterable of paths + paths : str or nested sequence of paths Either a string glob in the form ``"path/to/my/files/*.nc"`` or an explicit list of files to open. Paths can be given as strings or as pathlib Paths. If concatenation along more than one dimension is desired, then ``paths`` must be a From d8e39a0c598d39c195884ad419c08b5c98aa9d90 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 May 2022 17:07:17 +0000 Subject: [PATCH 42/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/api.py | 8 ++++++-- xarray/core/types.py | 4 ++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2adb875211e..006e2922c2a 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -43,9 +43,13 @@ from dask.delayed import Delayed except ImportError: Delayed = None # type: ignore - from ..core.types import CombineAttrsOptions, CompatOptions, JoinOptions + from ..core.types import ( + CombineAttrsOptions, + CompatOptions, + JoinOptions, + NestedSequence, + ) from .common import BackendEntrypoint - from ..core.types import NestedSequence T_NetcdfEngine = Literal["netcdf4", "scipy", "h5netcdf"] T_Engine = Union[ diff --git a/xarray/core/types.py b/xarray/core/types.py index 7c239481f49..5acf4f7b587 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, Sequence +from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence, TypeVar, Union import numpy as np @@ -54,4 +54,4 @@ Sequence[Sequence[_T]], Sequence[Sequence[Sequence[_T]]], Sequence[Sequence[Sequence[Sequence[_T]]]], -] \ No newline at end of file +] From 5043f85818c560ca3697f8e45e397b3abafe81bf Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 19:12:53 +0200 Subject: [PATCH 43/45] remove uneccesary import --- xarray/core/dataarray.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 569e3af09c2..8b89eba9c40 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -19,7 +19,6 @@ import numpy as np import pandas as pd -from ..backends.common import AbstractDataStore, ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex from ..plot.plot import _PlotMethods From fd400e5f3b64d0f4cf7a4dc5611a58b1111dce3b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 20:30:53 +0200 Subject: [PATCH 44/45] fix overloads of to_netcdf --- xarray/backends/api.py | 15 +++++++-------- xarray/core/dataarray.py | 3 ++- xarray/core/dataset.py | 5 +++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c1228f1379a..5dd486e952e 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1044,7 +1044,8 @@ def to_netcdf( encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, - multifile: Literal[True] = True, + *, + multifile: Literal[True], invalid_netcdf: bool = False, ) -> tuple[ArrayWriter, AbstractDataStore]: ... @@ -1062,8 +1063,7 @@ def to_netcdf( encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, compute: bool = True, - *, - multifile: Literal[False], + multifile: Literal[False] = False, invalid_netcdf: bool = False, ) -> bytes: ... @@ -1073,7 +1073,7 @@ def to_netcdf( @overload def to_netcdf( dataset: Dataset, - path_or_file: str | os.PathLike | None = None, + path_or_file: str | os.PathLike, mode: Literal["w", "a"] = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1082,7 +1082,7 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] | None = None, *, compute: Literal[False], - multifile: Literal[False], + multifile: Literal[False] = False, invalid_netcdf: bool = False, ) -> Delayed: ... @@ -1100,8 +1100,7 @@ def to_netcdf( encoding: Mapping[Hashable, Mapping[str, Any]] | None = None, unlimited_dims: Iterable[Hashable] | None = None, compute: Literal[True] = True, - *, - multifile: Literal[False], + multifile: Literal[False] = False, invalid_netcdf: bool = False, ) -> None: ... @@ -1467,7 +1466,7 @@ def to_zarr( ... 
-# computs=False returns dask.Delayed +# compute=False returns dask.Delayed @overload def to_zarr( dataset: Dataset, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8b89eba9c40..35c0aab3fb8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2929,7 +2929,7 @@ def to_netcdf( @overload def to_netcdf( self, - path: str | PathLike | None = None, + path: str | PathLike, mode: Literal["w", "a"] = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -3063,6 +3063,7 @@ def to_netcdf( encoding=encoding, unlimited_dims=unlimited_dims, compute=compute, + multifile=False, invalid_netcdf=invalid_netcdf, ) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 74533a824c9..2e434caa153 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1724,7 +1724,7 @@ def to_netcdf( @overload def to_netcdf( self, - path: str | PathLike | None = None, + path: str | PathLike, mode: Literal["w", "a"] = "w", format: T_NetcdfTypes | None = None, group: str | None = None, @@ -1829,13 +1829,14 @@ def to_netcdf( return to_netcdf( # type: ignore # mypy cannot resolve the overloads:( self, path, - mode, + mode=mode, format=format, group=group, engine=engine, encoding=encoding, unlimited_dims=unlimited_dims, compute=compute, + multifile=False, invalid_netcdf=invalid_netcdf, ) From e2777279a2ed8c1f986c98a0e8d571af35f54fa1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 17 May 2022 20:33:45 +0200 Subject: [PATCH 45/45] minor docstring update --- xarray/core/dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 2e434caa153..8cf5138c259 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1821,6 +1821,10 @@ def to_netcdf( * ``bytes`` if path is None * ``dask.delayed.Delayed`` if compute is False * None otherwise + + See Also + -------- + DataArray.to_netcdf """ if encoding is None: encoding = {}
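
The net effect of PATCH 44/45 is that a static type checker can infer the return type of ``Dataset.to_netcdf`` / ``DataArray.to_netcdf`` from the call signature alone, matching the return values listed in the docstring above. A minimal usage sketch of how the overloads are intended to resolve, assuming a build with these patches applied and scipy/netCDF4 plus dask available at runtime (file names below are placeholders):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"foo": ("x", np.random.randn(10))})

    # no path: serialize to an in-memory netCDF buffer -> bytes
    buf = ds.to_netcdf()

    # path given, compute=True (default): write eagerly -> None
    ds.to_netcdf("out1.nc")

    # path given, compute=False: defer the write -> dask.delayed.Delayed
    delayed = ds.to_netcdf("out2.nc", compute=False)
    delayed.compute()  # executes the deferred write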