diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 52114be991b..eda57136b91 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -53,6 +53,8 @@ Internal Changes By `Deepak Cherian `_. - Explicit indexes refactor: decouple ``xarray.Index``` from ``xarray.Variable`` (:pull:`5636`). By `Benoit Bovy `_. +- Fix ``Mapping`` argument typing to allow mypy to pass on ``str`` keys (:pull:`5690`). + By `Maximilian Roos `_. - Improve the performance of reprs for large datasets or dataarrays. (:pull:`5661`) By `Jimmy Westling `_. diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5fc097f1f5e..e8cef95c029 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -28,7 +28,7 @@ @st.composite -def datasets_1d_vars(draw): +def datasets_1d_vars(draw) -> xr.Dataset: """Generate datasets with only 1D variables Suitable for converting to pandas dataframes. @@ -49,7 +49,7 @@ def datasets_1d_vars(draw): @given(st.data(), an_array) -def test_roundtrip_dataarray(data, arr): +def test_roundtrip_dataarray(data, arr) -> None: names = data.draw( st.lists(st.text(), min_size=arr.ndim, max_size=arr.ndim, unique=True).map( tuple @@ -62,7 +62,7 @@ def test_roundtrip_dataarray(data, arr): @given(datasets_1d_vars()) -def test_roundtrip_dataset(dataset): +def test_roundtrip_dataset(dataset) -> None: df = dataset.to_dataframe() assert isinstance(df, pd.DataFrame) roundtripped = xr.Dataset(df) @@ -70,7 +70,7 @@ def test_roundtrip_dataset(dataset): @given(numeric_series, st.text()) -def test_roundtrip_pandas_series(ser, ix_name): +def test_roundtrip_pandas_series(ser, ix_name) -> None: # Need to name the index, otherwise Xarray calls it 'dim_0'. ser.index.name = ix_name arr = xr.DataArray(ser) @@ -87,7 +87,7 @@ def test_roundtrip_pandas_series(ser, ix_name): @pytest.mark.xfail @given(numeric_homogeneous_dataframe) -def test_roundtrip_pandas_dataframe(df): +def test_roundtrip_pandas_dataframe(df) -> None: # Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'. df.index.name = "rows" df.columns.name = "cols" diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index e3c35d6e4b6..0f0a256b77a 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -114,7 +114,7 @@ def _apply_str_ufunc( obj: Any, dtype: Union[str, np.dtype, Type] = None, output_core_dims: Union[list, tuple] = ((),), - output_sizes: Mapping[Hashable, int] = None, + output_sizes: Mapping[Any, int] = None, func_args: Tuple = (), func_kwargs: Mapping = {}, ) -> Any: @@ -227,7 +227,7 @@ def _apply( func: Callable, dtype: Union[str, np.dtype, Type] = None, output_core_dims: Union[list, tuple] = ((),), - output_sizes: Mapping[Hashable, int] = None, + output_sizes: Mapping[Any, int] = None, func_args: Tuple = (), func_kwargs: Mapping = {}, ) -> Any: diff --git a/xarray/core/common.py b/xarray/core/common.py index ab822f576d3..d3001532aa0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -818,7 +818,7 @@ def weighted( def rolling( self, - dim: Mapping[Hashable, int] = None, + dim: Mapping[Any, int] = None, min_periods: int = None, center: Union[bool, Mapping[Hashable, bool]] = False, **window_kwargs: int, @@ -892,7 +892,7 @@ def rolling( def rolling_exp( self, - window: Mapping[Hashable, int] = None, + window: Mapping[Any, int] = None, window_type: str = "span", **window_kwargs, ): @@ -933,7 +933,7 @@ def rolling_exp( def coarsen( self, - dim: Mapping[Hashable, int] = None, + dim: Mapping[Any, int] = None, boundary: str = "exact", side: Union[str, Mapping[Hashable, str]] = "left", coord_func: str = "mean", @@ -1009,7 +1009,7 @@ def coarsen( def resample( self, - indexer: Mapping[Hashable, str] = None, + indexer: Mapping[Any, str] = None, skipna=None, closed: str = None, label: str = None, diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cd9e22d90db..5bfb14793bb 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -400,7 +400,7 @@ def apply_dict_of_variables_vfunc( def _fast_dataset( - variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable] + variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable] ) -> "Dataset": """Create a dataset as quickly as possible. diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 767b76d0d12..56afdb9774a 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -158,7 +158,7 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: return pd.MultiIndex(level_list, code_list, names=names) - def update(self, other: Mapping[Hashable, Any]) -> None: + def update(self, other: Mapping[Any, Any]) -> None: other_vars = getattr(other, "variables", other) coords, indexes = merge_coords( [self.variables, other_vars], priority_arg=1, indexes=self.xindexes @@ -270,7 +270,7 @@ def to_dataset(self) -> "Dataset": return self._data._copy_listed(names) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: from .dataset import calculate_dimensions @@ -333,7 +333,7 @@ def __getitem__(self, key: Hashable) -> "DataArray": return self._data._getitem_coord(key) def _update_coords( - self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, Index] + self, coords: Dict[Hashable, Variable], indexes: Mapping[Any, Index] ) -> None: from .dataset import calculate_dimensions @@ -376,7 +376,7 @@ def _ipython_key_completions_(self): def assert_coordinate_consistent( - obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable] + obj: Union["DataArray", "Dataset"], coords: Mapping[Any, Variable] ) -> None: """Make sure the dimension coordinate of obj is consistent with coords. @@ -394,7 +394,7 @@ def assert_coordinate_consistent( def remap_label_indexers( obj: Union["DataArray", "Dataset"], - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, **indexers_kwargs: Any, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 900af885319..e86bae08a3f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -462,7 +462,7 @@ def _replace_maybe_drop_dims( ) return self._replace(variable, coords, name, indexes=indexes) - def _overwrite_indexes(self, indexes: Mapping[Hashable, Any]) -> "DataArray": + def _overwrite_indexes(self, indexes: Mapping[Any, Any]) -> "DataArray": if not len(indexes): return self coords = self._coords.copy() @@ -792,7 +792,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self.variable.attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: # Disable type checking to work around mypy bug - see mypy#4167 self.variable.attrs = value # type: ignore[assignment] @@ -803,7 +803,7 @@ def encoding(self) -> Dict[Hashable, Any]: return self.variable.encoding @encoding.setter - def encoding(self, value: Mapping[Hashable, Any]) -> None: + def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = value @property @@ -1110,7 +1110,7 @@ def chunk( def isel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, drop: bool = False, missing_dims: str = "raise", **indexers_kwargs: Any, @@ -1193,7 +1193,7 @@ def isel( def sel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, drop: bool = False, @@ -1498,7 +1498,7 @@ def reindex_like( def reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance=None, copy: bool = True, @@ -1591,7 +1591,7 @@ def reindex( def interp( self, - coords: Mapping[Hashable, Any] = None, + coords: Mapping[Any, Any] = None, method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, @@ -1815,7 +1815,7 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs ) -> "DataArray": """Returns a new DataArray with swapped dimensions. @@ -2333,7 +2333,7 @@ def drop( def drop_sel( self, - labels: Mapping[Hashable, Any] = None, + labels: Mapping[Any, Any] = None, *, errors: str = "raise", **labels_kwargs, @@ -3163,7 +3163,7 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr def shift( self, - shifts: Mapping[Hashable, int] = None, + shifts: Mapping[Any, int] = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, ) -> "DataArray": @@ -3210,7 +3210,7 @@ def shift( def roll( self, - shifts: Mapping[Hashable, int] = None, + shifts: Mapping[Any, int] = None, roll_coords: bool = None, **shifts_kwargs: int, ) -> "DataArray": @@ -4433,7 +4433,7 @@ def argmax( def query( self, - queries: Mapping[Hashable, Any] = None, + queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, missing_dims: str = "raise", diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4bfc1ccbdf1..90c395ed39b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -185,7 +185,7 @@ def _get_virtual_variable( return ref_name, var_name, virtual_var -def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashable, int]: +def calculate_dimensions(variables: Mapping[Any, Variable]) -> Dict[Hashable, int]: """Calculate the dimensions corresponding to a set of variables. Returns dictionary mapping from dimension names to sizes. Raises ValueError @@ -213,7 +213,7 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl def merge_indexes( indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]], - variables: Mapping[Hashable, Variable], + variables: Mapping[Any, Variable], coord_names: Set[Hashable], append: bool = False, ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: @@ -297,9 +297,9 @@ def merge_indexes( def split_indexes( dims_or_levels: Union[Hashable, Sequence[Hashable]], - variables: Mapping[Hashable, Variable], + variables: Mapping[Any, Variable], coord_names: Set[Hashable], - level_coords: Mapping[Hashable, Hashable], + level_coords: Mapping[Any, Hashable], drop: bool = False, ) -> Tuple[Dict[Hashable, Variable], Set[Hashable]]: """Extract (multi-)indexes (levels) as variables. @@ -559,7 +559,7 @@ class _LocIndexer: def __init__(self, dataset: "Dataset"): self.dataset = dataset - def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": + def __getitem__(self, key: Mapping[Any, Any]) -> "Dataset": if not utils.is_dict_like(key): raise TypeError("can only lookup dictionaries from Dataset.loc") return self.dataset.sel(key) @@ -730,9 +730,9 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Hashable, Any] = None, - coords: Mapping[Hashable, Any] = None, - attrs: Mapping[Hashable, Any] = None, + data_vars: Mapping[Any, Any] = None, + coords: Mapping[Any, Any] = None, + attrs: Mapping[Any, Any] = None, ): # TODO(shoyer): expose indexes as a public argument in __init__ @@ -793,7 +793,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self._attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property @@ -2164,7 +2164,7 @@ def chunk( return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Hashable, Any], missing_dims: str = "raise" + self, indexers: Mapping[Any, Any], missing_dims: str = "raise" ) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]: """Here we make sure + indexer has a valid keys @@ -2208,7 +2208,7 @@ def _validate_indexers( yield k, v def _validate_interp_indexers( - self, indexers: Mapping[Hashable, Any] + self, indexers: Mapping[Any, Any] ) -> Iterator[Tuple[Hashable, Variable]]: """Variant of _validate_indexers to be used for interpolation""" for k, v in self._validate_indexers(indexers): @@ -2269,7 +2269,7 @@ def _get_indexers_coords_and_indexes(self, indexers): def isel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, drop: bool = False, missing_dims: str = "raise", **indexers_kwargs: Any, @@ -2361,7 +2361,7 @@ def isel( def _isel_fancy( self, - indexers: Mapping[Hashable, Any], + indexers: Mapping[Any, Any], *, drop: bool, missing_dims: str = "raise", @@ -2403,7 +2403,7 @@ def _isel_fancy( def sel( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, drop: bool = False, @@ -2711,7 +2711,7 @@ def reindex_like( def reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, copy: bool = True, @@ -2921,7 +2921,7 @@ def reindex( def _reindex( self, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, method: str = None, tolerance: Number = None, copy: bool = True, @@ -2955,7 +2955,7 @@ def _reindex( def interp( self, - coords: Mapping[Hashable, Any] = None, + coords: Mapping[Any, Any] = None, method: str = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] = None, @@ -3325,7 +3325,7 @@ def _rename_all(self, name_dict, dims_dict): def rename( self, - name_dict: Mapping[Hashable, Hashable] = None, + name_dict: Mapping[Any, Hashable] = None, **names: Hashable, ) -> "Dataset": """Returns a new object with renamed variables and dimensions. @@ -3366,7 +3366,7 @@ def rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims: Hashable + self, dims_dict: Mapping[Any, Hashable] = None, **dims: Hashable ) -> "Dataset": """Returns a new object with renamed dimensions only. @@ -3411,7 +3411,7 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self, name_dict: Mapping[Hashable, Hashable] = None, **names: Hashable + self, name_dict: Mapping[Any, Hashable] = None, **names: Hashable ) -> "Dataset": """Returns a new object with renamed variables including coordinates @@ -3449,7 +3449,7 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self, dims_dict: Mapping[Hashable, Hashable] = None, **dims_kwargs + self, dims_dict: Mapping[Any, Hashable] = None, **dims_kwargs ) -> "Dataset": """Returns a new object with swapped dimensions. @@ -5146,7 +5146,7 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self, variables: Mapping[Hashable, Any] = None, **variables_kwargs: Hashable + self, variables: Mapping[Any, Any] = None, **variables_kwargs: Hashable ) -> "Dataset": """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. @@ -5322,7 +5322,7 @@ def to_pandas(self) -> Union[pd.Series, pd.DataFrame]: "Please use Dataset.to_dataframe() instead." % len(self.dims) ) - def _to_dataframe(self, ordered_dims: Mapping[Hashable, int]): + def _to_dataframe(self, ordered_dims: Mapping[Any, int]): columns = [k for k in self.variables if k not in self.dims] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) @@ -7391,7 +7391,7 @@ def argmax(self, dim=None, **kwargs): def query( self, - queries: Mapping[Hashable, Any] = None, + queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, missing_dims: str = "raise", diff --git a/xarray/core/merge.py b/xarray/core/merge.py index b8b32bdaa01..eaa5b62b2a9 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -170,7 +170,7 @@ def _assert_compat_valid(compat): def merge_collected( grouped: Dict[Hashable, List[MergeElement]], - prioritized: Mapping[Hashable, MergeElement] = None, + prioritized: Mapping[Any, MergeElement] = None, compat: str = "minimal", combine_attrs="override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: @@ -319,7 +319,7 @@ def collect_from_coordinates( def merge_coordinates_without_align( objects: "List[Coordinates]", - prioritized: Mapping[Hashable, MergeElement] = None, + prioritized: Mapping[Any, MergeElement] = None, exclude_dims: AbstractSet = frozenset(), combine_attrs: str = "override", ) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, Index]]: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index a139d2ef10a..57b7035c940 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -816,7 +816,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( - indexers: Mapping[Hashable, Any], + indexers: Mapping[Any, Any], dims: Union[list, Mapping[Hashable, int]], missing_dims: str, ) -> Mapping[Hashable, Any]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6b971389de7..2798a4ab956 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -876,7 +876,7 @@ def attrs(self) -> Dict[Hashable, Any]: return self._attrs @attrs.setter - def attrs(self, value: Mapping[Hashable, Any]) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) @property @@ -1137,7 +1137,7 @@ def _to_dense(self): def isel( self: VariableType, - indexers: Mapping[Hashable, Any] = None, + indexers: Mapping[Any, Any] = None, missing_dims: str = "raise", **indexers_kwargs: Any, ) -> VariableType: @@ -1572,7 +1572,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): return result def _unstack_once_full( - self, dims: Mapping[Hashable, int], old_dim: Hashable + self, dims: Mapping[Any, int], old_dim: Hashable ) -> "Variable": """ Unstacks the variable without needing an index. diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9b8b7c748f1..1256a44ad81 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6802,3 +6802,12 @@ def test_from_pint_wrapping_dask(self): result = ds.as_numpy() expected = xr.Dataset({"a": ("x", arr)}, coords={"lat": ("x", arr * 2)}) assert_identical(result, expected) + + +def test_string_keys_typing() -> None: + """Tests that string keys to `variables` are permitted by mypy""" + + da = xr.DataArray(np.arange(10), dims=["x"]) + ds = xr.Dataset(dict(x=da)) + mapping = {"y": da} + ds.assign(variables=mapping)