From a2f521fa0e3c592d90f88f13e3c7e43231b3258d Mon Sep 17 00:00:00 2001
From: Maximilian Roos
Date: Tue, 10 Aug 2021 22:11:54 -0700
Subject: [PATCH 1/4] Change typing to allow str keys

---
 xarray/core/dataset.py       | 2 +-
 xarray/tests/test_dataset.py | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 5f5c01ad4c9..eaab8588b6f 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -5137,7 +5137,7 @@ def apply(
         return self.map(func, keep_attrs, args, **kwargs)

     def assign(
-        self, variables: Mapping[Hashable, Any] = None, **variables_kwargs: Hashable
+        self, variables: Mapping[Any, Any] = None, **variables_kwargs: Hashable
     ) -> "Dataset":
         """Assign new data variables to a Dataset, returning a new object
         with all the original variables in addition to the new ones.
diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py
index 02d27ade161..abf14e41bde 100644
--- a/xarray/tests/test_dataset.py
+++ b/xarray/tests/test_dataset.py
@@ -6795,3 +6795,12 @@ def test_from_pint_wrapping_dask(self):
         result = ds.as_numpy()
         expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)})
         assert_identical(result, expected)
+
+
+def test_string_keys_typing() -> None:
+    """Tests that string keys to `variables` are permitted by mypy"""
+
+    da = xr.DataArray(np.arange(10), dims=["x"])
+    ds = xr.Dataset(dict(x=da))
+    mapping = {"y": da}
+    ds.assign(variables=mapping)

From e9168b798d183ccb1f477313ce458922802012af Mon Sep 17 00:00:00 2001
From: Maximilian Roos
Date: Tue, 10 Aug 2021 22:47:18 -0700
Subject: [PATCH 2/4] Change all incoming Mapping types

---
 xarray/core/accessor_str.py |  4 ++--
 xarray/core/common.py       |  8 +++----
 xarray/core/computation.py  |  2 +-
 xarray/core/coordinates.py  | 10 ++++----
 xarray/core/dataarray.py    | 24 +++++++++----------
 xarray/core/dataset.py      | 46 ++++++++++++++++++-------------------
 xarray/core/merge.py        |  4 ++--
 xarray/core/utils.py        |  2 +-
 xarray/core/variable.py     |  6 ++---
 9 files changed, 53 insertions(+), 53 deletions(-)

diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py
index e3c35d6e4b6..0f0a256b77a 100644
--- a/xarray/core/accessor_str.py
+++ b/xarray/core/accessor_str.py
@@ -114,7 +114,7 @@ def _apply_str_ufunc(
     obj: Any,
     dtype: Union[str, np.dtype, Type] = None,
     output_core_dims: Union[list, tuple] = ((),),
-    output_sizes: Mapping[Hashable, int] = None,
+    output_sizes: Mapping[Any, int] = None,
     func_args: Tuple = (),
     func_kwargs: Mapping = {},
 ) -> Any:
@@ -227,7 +227,7 @@ def _apply(
         func: Callable,
         dtype: Union[str, np.dtype, Type] = None,
         output_core_dims: Union[list, tuple] = ((),),
-        output_sizes: Mapping[Hashable, int] = None,
+        output_sizes: Mapping[Any, int] = None,
         func_args: Tuple = (),
         func_kwargs: Mapping = {},
     ) -> Any:
diff --git a/xarray/core/common.py b/xarray/core/common.py
index ab822f576d3..d3001532aa0 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -818,7 +818,7 @@ def weighted(

     def rolling(
         self,
-        dim: Mapping[Hashable, int] = None,
+        dim: Mapping[Any, int] = None,
         min_periods: int = None,
         center: Union[bool, Mapping[Hashable, bool]] = False,
         **window_kwargs: int,
@@ -892,7 +892,7 @@ def rolling(

     def rolling_exp(
         self,
-        window: Mapping[Hashable, int] = None,
+        window: Mapping[Any, int] = None,
         window_type: str = "span",
         **window_kwargs,
     ):
@@ -933,7 +933,7 @@ def rolling_exp(

     def coarsen(
         self,
-        dim: Mapping[Hashable, int] = None,
+        dim: Mapping[Any, int] = None,
         boundary: str = "exact",
         side: Union[str, Mapping[Hashable, str]] = "left",
         coord_func: str = "mean",
@@ -1009,7 +1009,7 @@ def coarsen(

     def resample(
         self,
-        indexer: Mapping[Hashable, str] = None,
+        indexer: Mapping[Any, str] = None,
         skipna=None,
         closed: str = None,
         label: str = None,
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index cd9e22d90db..5bfb14793bb 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -400,7 +400,7 @@ def apply_dict_of_variables_vfunc(


 def _fast_dataset(
-    variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
+    variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable]
 ) -> "Dataset":
     """Create a dataset as quickly as possible.

diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py
index 767b76d0d12..56afdb9774a 100644
--- a/xarray/core/coordinates.py
+++ b/xarray/core/coordinates.py
@@ -158,7 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:

             return pd.MultiIndex(level_list, code_list, names=names)

-    def update(self, other: Mapping[Hashable, Any]) -> None:
+    def update(self, other: Mapping[Any, Any]) -> None:
         other_vars = getattr(other, "variables", other)
         coords, indexes = merge_coords(
             [self.variables, other_vars], priority_arg=1, indexes=self.xindexes
@@ -270,7 +270,7 @@ def to_dataset(self) -> "Dataset":
         return self._data._copy_listed(names)

     def _update_coords(
-        self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
+        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
     ) -> None:
         from .dataset import calculate_dimensions

@@ -333,7 +333,7 @@ def __getitem__(self, key: Hashable) -> "DataArray":
         return self._data._getitem_coord(key)

     def _update_coords(
-        self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index]
+        self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index]
     ) -> None:
         from .dataset import calculate_dimensions

@@ -376,7 +376,7 @@ def _ipython_key_completions_(self):


 def assert_coordinate_consistent(
-    obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable]
+    obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable]
 ) -> None:
     """Make sure the dimension coordinate of obj is consistent with coords.

@@ -394,7 +394,7 @@ def assert_coordinate_consistent(

 def remap_label_indexers(
     obj: Union["DataArray", "Dataset"],
-    indexers: Mapping[Hashable, Any] = None,
+    indexers: Mapping[Any, Any] = None,
     method: str = None,
     tolerance=None,
     **indexers_kwargs: Any,
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index cb2c4d30a69..94d70f80408 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -468,7 +468,7 @@ def _replace_maybe_drop_dims(
         )
         return self._replace(variable, coords, name, indexes=indexes)

-    def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray":
+    def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> "DataArray":
         if not len(indexes):
             return self
         coords = self._coords.copy()
@@ -799,7 +799,7 @@ def attrs(self) -> Dict[Hashable, Any]:
         return self.variable.attrs

     @attrs.setter
-    def attrs(self, value: Mapping[Hashable, Any]) -> None:
+    def attrs(self, value: Mapping[Any, Any]) -> None:
         # Disable type checking to work around mypy bug - see mypy#4167
         self.variable.attrs = value  # type: ignore[assignment]

@@ -810,7 +810,7 @@ def encoding(self) -> Dict[Hashable, Any]:
         return self.variable.encoding

     @encoding.setter
-    def encoding(self, value: Mapping[Hashable, Any]) -> None:
+    def encoding(self, value: Mapping[Any, Any]) -> None:
         self.variable.encoding = value

     @property
@@ -1122,7 +1122,7 @@ def chunk(

     def isel(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         drop: bool = False,
         missing_dims: str = "raise",
         **indexers_kwargs: Any,
@@ -1205,7 +1205,7 @@ def isel(

     def sel(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         method: str = None,
         tolerance=None,
         drop: bool = False,
@@ -1510,7 +1510,7 @@ def reindex_like(

     def reindex(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         method: str = None,
         tolerance=None,
         copy: bool = True,
@@ -1603,7 +1603,7 @@ def reindex(

     def interp(
         self,
-        coords: Mapping[Hashable, Any] = None,
+        coords: Mapping[Any, Any] = None,
         method: str = "linear",
         assume_sorted: bool = False,
         kwargs: Mapping[str, Any] = None,
@@ -1827,7 +1827,7 @@ def rename(
         return self._replace(name=new_name_or_name_dict)

     def swap_dims(
-        self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
+        self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs
     ) -> "DataArray":
         """Returns a new DataArray with swapped dimensions.

@@ -2345,7 +2345,7 @@ def drop(

     def drop_sel(
         self,
-        labels: Mapping[Hashable, Any] = None,
+        labels: Mapping[Any, Any] = None,
         *,
         errors: str = "raise",
         **labels_kwargs,
@@ -3175,7 +3175,7 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr

     def shift(
         self,
-        shifts: Mapping[Hashable, int] = None,
+        shifts: Mapping[Any, int] = None,
         fill_value: Any = dtypes.NA,
         **shifts_kwargs: int,
     ) -> "DataArray":
@@ -3222,7 +3222,7 @@ def shift(

     def roll(
         self,
-        shifts: Mapping[Hashable, int] = None,
+        shifts: Mapping[Any, int] = None,
         roll_coords: bool = None,
         **shifts_kwargs: int,
     ) -> "DataArray":
@@ -4445,7 +4445,7 @@ def argmax(

     def query(
         self,
-        queries: Mapping[Hashable, Any] = None,
+        queries: Mapping[Any, Any] = None,
         parser: str = "pandas",
         engine: str = None,
         missing_dims: str = "raise",
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index eaab8588b6f..5b0d577b297 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -186,7 +186,7 @@ def _get_virtual_variable(
     return ref_name, var_name, virtual_var


-def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashable, int]:
+def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, int]:
     """Calculate the dimensions corresponding to a set of variables.

     Returns dictionary mapping from dimension names to sizes. Raises ValueError
@@ -214,7 +214,7 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl

 def merge_indexes(
     indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]],
-    variables: Mapping[Hashable, Variable],
+    variables: Mapping[Any, Variable],
     coord_names: Set[Hashable],
     append: bool = False,
 ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
@@ -298,9 +298,9 @@ def merge_indexes(

 def split_indexes(
     dims_or_levels: Union[Hashable, Sequence[Hashable]],
-    variables: Mapping[Hashable, Variable],
+    variables: Mapping[Any, Variable],
     coord_names: Set[Hashable],
-    level_coords: Mapping[Hashable, Hashable],
+    level_coords: Mapping[Any, Hashable],
     drop: bool = False,
 ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]:
     """Extract (multi-)indexes (levels) as variables.
@@ -560,7 +560,7 @@ class _LocIndexer:
     def __init__(self, dataset: "Dataset"):
         self.dataset = dataset

-    def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset":
+    def __getitem__(self, key: Mapping[Any, Any]) -> "Dataset":
         if not utils.is_dict_like(key):
             raise TypeError("can only lookup dictionaries from Dataset.loc")
         return self.dataset.sel(key)
@@ -731,9 +731,9 @@ def __init__(
         self,
         # could make a VariableArgs to use more generally, and refine these
         # categories
-        data_vars: Mapping[Hashable, Any] = None,
-        coords: Mapping[Hashable, Any] = None,
-        attrs: Mapping[Hashable, Any] = None,
+        data_vars: Mapping[Any, Any] = None,
+        coords: Mapping[Any, Any] = None,
+        attrs: Mapping[Any, Any] = None,
     ):
         # TODO(shoyer): expose indexes as a public argument in __init__

@@ -794,7 +794,7 @@ def attrs(self) -> Dict[Hashable, Any]:
         return self._attrs

     @attrs.setter
-    def attrs(self, value: Mapping[Hashable, Any]) -> None:
+    def attrs(self, value: Mapping[Any, Any]) -> None:
         self._attrs = dict(value)

     @property
@@ -2165,7 +2165,7 @@ def chunk(
         return self._replace(variables)

     def _validate_indexers(
-        self, indexers: Mapping[Hashable, Any], missing_dims: str = "raise"
+        self, indexers: Mapping[Any, Any], missing_dims: str = "raise"
     ) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
         """Here we make sure
         + indexer has a valid keys
@@ -2209,7 +2209,7 @@ def _validate_indexers(
             yield k, v

     def _validate_interp_indexers(
-        self, indexers: Mapping[Hashable, Any]
+        self, indexers: Mapping[Any, Any]
     ) -> Iterator[Tuple[Hashable, Variable]]:
         """Variant of _validate_indexers to be used for interpolation"""
         for k, v in self._validate_indexers(indexers):
@@ -2270,7 +2270,7 @@ def _get_indexers_coords_and_indexes(self, indexers):

     def isel(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         drop: bool = False,
         missing_dims: str = "raise",
         **indexers_kwargs: Any,
@@ -2362,7 +2362,7 @@ def isel(

     def _isel_fancy(
         self,
-        indexers: Mapping[Hashable, Any],
+        indexers: Mapping[Any, Any],
         *,
         drop: bool,
         missing_dims: str = "raise",
@@ -2404,7 +2404,7 @@ def _isel_fancy(

     def sel(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         method: str = None,
         tolerance: Number = None,
         drop: bool = False,
@@ -2708,7 +2708,7 @@ def reindex_like(

     def reindex(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         method: str = None,
         tolerance: Number = None,
         copy: bool = True,
@@ -2918,7 +2918,7 @@ def reindex(

     def _reindex(
         self,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         method: str = None,
         tolerance: Number = None,
         copy: bool = True,
@@ -2952,7 +2952,7 @@ def _reindex(

     def interp(
         self,
-        coords: Mapping[Hashable, Any] = None,
+        coords: Mapping[Any, Any] = None,
         method: str = "linear",
         assume_sorted: bool = False,
         kwargs: Mapping[str, Any] = None,
@@ -3321,7 +3321,7 @@ def _rename_all(self, name_dict, dims_dict):

     def rename(
         self,
-        name_dict: Mapping[Hashable, Hashable] = None,
+        name_dict: Mapping[Any, Hashable] = None,
         **names: Hashable,
     ) -> "Dataset":
         """Returns a new object with renamed variables and dimensions.
@@ -3362,7 +3362,7 @@ def rename(
         return self._replace(variables, coord_names, dims=dims, indexes=indexes)

     def rename_dims(
-        self, dims_dict: Mapping[Hashable, Hashable] = None, **dims: Hashable
+        self, dims_dict: Mapping[Any, Hashable] = None, **dims: Hashable
     ) -> "Dataset":
         """Returns a new object with renamed dimensions only.

@@ -3407,7 +3407,7 @@ def rename_dims(
         return self._replace(variables, coord_names, dims=sizes, indexes=indexes)

     def rename_vars(
-        self, name_dict: Mapping[Hashable, Hashable] = None, **names: Hashable
+        self, name_dict: Mapping[Any, Hashable] = None, **names: Hashable
     ) -> "Dataset":
         """Returns a new object with renamed variables including coordinates

@@ -3445,7 +3445,7 @@ def rename_vars(
         return self._replace(variables, coord_names, dims=dims, indexes=indexes)

     def swap_dims(
-        self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs
+        self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs
     ) -> "Dataset":
         """Returns a new object with swapped dimensions.

@@ -5313,7 +5313,7 @@ def to_pandas(self) -> Union[pd.Series, pd.DataFrame]:
             "Please use Dataset.to_dataframe() instead." % len(self.dims)
         )

-    def _to_dataframe(self, ordered_dims: Mapping[Hashable, int]):
+    def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
         columns = [k for k in self.variables if k not in self.dims]
         data = [
             self._variables[k].set_dims(ordered_dims).values.reshape(-1)
@@ -7373,7 +7373,7 @@ def argmax(self, dim=None, **kwargs):

     def query(
         self,
-        queries: Mapping[Hashable, Any] = None,
+        queries: Mapping[Any, Any] = None,
         parser: str = "pandas",
         engine: str = None,
         missing_dims: str = "raise",
diff --git a/xarray/core/merge.py b/xarray/core/merge.py
index db5b95fd415..a079197d344 100644
--- a/xarray/core/merge.py
+++ b/xarray/core/merge.py
@@ -170,7 +170,7 @@ def _assert_compat_valid(compat):

 def merge_collected(
     grouped: Dict[Hashable, List[MergeElement]],
-    prioritized: Mapping[Hashable, MergeElement] = None,
+    prioritized: Mapping[Any, MergeElement] = None,
     compat: str = "minimal",
     combine_attrs="override",
 ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
@@ -319,7 +319,7 @@ def collect_from_coordinates(

 def merge_coordinates_without_align(
     objects: "List[Coordinates]",
-    prioritized: Mapping[Hashable, MergeElement] = None,
+    prioritized: Mapping[Any, MergeElement] = None,
     exclude_dims: AbstractSet = frozenset(),
     combine_attrs: str = "override",
 ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]:
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index a139d2ef10a..57b7035c940 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -816,7 +816,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:


 def drop_dims_from_indexers(
-    indexers: Mapping[Hashable, Any],
+    indexers: Mapping[Any, Any],
     dims: Union[list, Mapping[Hashable, int]],
     missing_dims: str,
 ) -> Mapping[Hashable, Any]:
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index f69951580c7..b3b019179a7 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -861,7 +861,7 @@ def attrs(self) -> Dict[Hashable, Any]:
         return self._attrs

     @attrs.setter
-    def attrs(self, value: Mapping[Hashable, Any]) -> None:
+    def attrs(self, value: Mapping[Any, Any]) -> None:
         self._attrs = dict(value)

     @property
@@ -1122,7 +1122,7 @@ def _to_dense(self):

     def isel(
         self: VariableType,
-        indexers: Mapping[Hashable, Any] = None,
+        indexers: Mapping[Any, Any] = None,
         missing_dims: str = "raise",
         **indexers_kwargs: Any,
     ) -> VariableType:
@@ -1557,7 +1557,7 @@ def stack(self, dimensions=None, **dimensions_kwargs):
         return result

     def _unstack_once_full(
-        self, dims: Mapping[Hashable, int], old_dim: Hashable
+        self, dims: Mapping[Any, int], old_dim: Hashable
     ) -> "Variable":
         """
         Unstacks the variable without needing an index.

From 78d4dd62d118e0afe490db4fb797cdd5e98d166c Mon Sep 17 00:00:00 2001
From: Maximilian Roos
Date: Wed, 11 Aug 2021 18:36:35 -0700
Subject: [PATCH 3/4] Add in some annotated tests

---
 properties/test_pandas_roundtrip.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py
index 5fc097f1f5e..e8cef95c029 100644
--- a/properties/test_pandas_roundtrip.py
+++ b/properties/test_pandas_roundtrip.py
@@ -28,7 +28,7 @@


 @st.composite
-def datasets_1d_vars(draw):
+def datasets_1d_vars(draw) -> xr.Dataset:
     """Generate datasets with only 1D variables

     Suitable for converting to pandas dataframes.
@@ -49,7 +49,7 @@ def datasets_1d_vars(draw):


 @given(st.data(), an_array)
-def test_roundtrip_dataarray(data, arr):
+def test_roundtrip_dataarray(data, arr) -> None:
     names = data.draw(
         st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map(
             tuple
@@ -62,7 +62,7 @@ def test_roundtrip_dataarray(data, arr):


 @given(datasets_1d_vars())
-def test_roundtrip_dataset(dataset):
+def test_roundtrip_dataset(dataset) -> None:
     df = dataset.to_dataframe()
     assert isinstance(df, pd.DataFrame)
     roundtripped = xr.Dataset(df)
@@ -70,7 +70,7 @@ def test_roundtrip_dataset(dataset):


 @given(numeric_series, st.text())
-def test_roundtrip_pandas_series(ser, ix_name):
+def test_roundtrip_pandas_series(ser, ix_name) -> None:
     # Need to name the index, otherwise Xarray calls it 'dim_0'.
     ser.index.name = ix_name
     arr = xr.DataArray(ser)
@@ -87,7 +87,7 @@ def test_roundtrip_pandas_series(ser, ix_name):

 @pytest.mark.xfail
 @given(numeric_homogeneous_dataframe)
-def test_roundtrip_pandas_dataframe(df):
+def test_roundtrip_pandas_dataframe(df) -> None:
     # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'.
     df.index.name = "rows"
     df.columns.name = "cols"

From 7eb505a438b36179d3fffa45496dadac331de413 Mon Sep 17 00:00:00 2001
From: Maximilian Roos
Date: Wed, 18 Aug 2021 22:49:59 -0700
Subject: [PATCH 4/4] whatsnew

---
 doc/whats-new.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 52114be991b..eda57136b91 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -53,6 +53,8 @@ Internal Changes
   By `Deepak Cherian `_.
 - Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`).
   By `Benoit Bovy `_.
+- Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`).
+  By `Maximilian Roos `_.
 - Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`)
   By `Jimmy Westling `_.
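
Illustrative note (a minimal sketch, not part of the patch series): because ``typing.Mapping`` is
invariant in its key type, a ``Dict[str, ...]`` argument did not satisfy a parameter annotated as
``Mapping[Hashable, ...]`` under mypy. With the relaxed ``Mapping[Any, ...]`` annotations above,
code like the following now type-checks, mirroring the new ``test_string_keys_typing`` test::

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.arange(10), dims=["x"])
    ds = xr.Dataset({"x": da})

    # Plain str-keyed dicts now satisfy the Mapping-typed parameters under mypy
    ds.assign(variables={"y": da})
    ds.isel(indexers={"x": 0})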