Skip to content

Commit

Permalink
{full,zeros,ones}_like typing (#6611)
Browse files Browse the repository at this point in the history
* type {full,zeros,ones}_like

* fix modern numpy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* python3.8 support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typo

* apply patch from max-sixty

* add link to numpy.typing.DTypeLike

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
headtr1ck and pre-commit-ci[bot] authored May 16, 2022
1 parent 8de7061 commit e712270
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 40 deletions.
122 changes: 104 additions & 18 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
Iterator,
Mapping,
TypeVar,
Union,
overload,
)

import numpy as np
import pandas as pd

from . import dtypes, duck_array_ops, formatting, formatting_html, ops
from .npcompat import DTypeLike
from .npcompat import DTypeLike, DTypeLikeSave
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .rolling_exp import RollingExp
Expand Down Expand Up @@ -1577,26 +1578,45 @@ def __getitem__(self, value):
raise NotImplementedError()


DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]


@overload
def full_like(
other: Dataset,
fill_value,
dtype: DTypeLike | Mapping[Any, DTypeLike] = None,
) -> Dataset:
def full_like(other: DataArray, fill_value: Any, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def full_like(other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def full_like(other: DataArray, fill_value, dtype: DTypeLike = None) -> DataArray:
def full_like(other: Variable, fill_value: Any, dtype: DTypeLikeSave) -> Variable:
...


@overload
def full_like(other: Variable, fill_value, dtype: DTypeLike = None) -> Variable:
def full_like(
other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
...


def full_like(other, fill_value, dtype=None):
@overload
def full_like(
other: Dataset | DataArray | Variable,
fill_value: Any,
dtype: DTypeMaybeMapping = None,
) -> Dataset | DataArray | Variable:
...


def full_like(
other: Dataset | DataArray | Variable,
fill_value: Any,
dtype: DTypeMaybeMapping = None,
) -> Dataset | DataArray | Variable:
"""Return a new object with the same shape and type as a given object.
Parameters
Expand Down Expand Up @@ -1711,26 +1731,26 @@ def full_like(other, fill_value, dtype=None):
f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
)

if not isinstance(other, Dataset) and isinstance(dtype, Mapping):
raise ValueError(
"'dtype' cannot be dict-like when passing a DataArray or Variable"
)

if isinstance(other, Dataset):
if not isinstance(fill_value, dict):
fill_value = {k: fill_value for k in other.data_vars.keys()}

dtype_: Mapping[Any, DTypeLikeSave]
if not isinstance(dtype, Mapping):
dtype_ = {k: dtype for k in other.data_vars.keys()}
else:
dtype_ = dtype

data_vars = {
k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None))
k: _full_like_variable(
v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None)
)
for k, v in other.data_vars.items()
}
return Dataset(data_vars, coords=other.coords, attrs=other.attrs)
elif isinstance(other, DataArray):
if isinstance(dtype, Mapping):
raise ValueError("'dtype' cannot be dict-like when passing a DataArray")
return DataArray(
_full_like_variable(other.variable, fill_value, dtype),
dims=other.dims,
Expand All @@ -1739,12 +1759,16 @@ def full_like(other, fill_value, dtype=None):
name=other.name,
)
elif isinstance(other, Variable):
if isinstance(dtype, Mapping):
raise ValueError("'dtype' cannot be dict-like when passing a Variable")
return _full_like_variable(other, fill_value, dtype)
else:
raise TypeError("Expected DataArray, Dataset, or Variable")


def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
def _full_like_variable(
other: Variable, fill_value: Any, dtype: DTypeLike = None
) -> Variable:
"""Inner function of full_like, where other must be a variable"""
from .variable import Variable

Expand All @@ -1765,7 +1789,38 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
return Variable(dims=other.dims, data=data, attrs=other.attrs)


def zeros_like(other, dtype: DTypeLike = None):
@overload
def zeros_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def zeros_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def zeros_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
...


@overload
def zeros_like(
other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
...


@overload
def zeros_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
...


def zeros_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
"""Return a new object of zeros with the same shape and
type as a given dataarray or dataset.
Expand Down Expand Up @@ -1821,7 +1876,38 @@ def zeros_like(other, dtype: DTypeLike = None):
return full_like(other, 0, dtype)


def ones_like(other, dtype: DTypeLike = None):
@overload
def ones_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def ones_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def ones_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
...


@overload
def ones_like(
other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
...


@overload
def ones_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
...


def ones_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
"""Return a new object of ones with the same shape and
type as a given dataarray or dataset.
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ def polyval(
coeffs = coeffs.reindex(
{degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
)
coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ?
coord = _ensure_numeric(coord)

# using Horner's method
# https://en.wikipedia.org/wiki/Horner%27s_method
Expand All @@ -1917,7 +1917,7 @@ def polyval(
return res


def _ensure_numeric(data: T_Xarray) -> T_Xarray:
def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
"""Converts all datetime64 variables to float64
Parameters
Expand Down
22 changes: 15 additions & 7 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import datetime as dt
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Hashable, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence

import numpy as np
import pandas as pd
Expand All @@ -17,8 +19,14 @@
from .utils import OrderedSet, is_scalar
from .variable import Variable, broadcast_variables

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset


def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
def _get_nan_block_lengths(
obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
):
"""
Return an object where each NaN element in 'obj' is replaced by the
length of the gap the element is in.
Expand Down Expand Up @@ -48,8 +56,8 @@ def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
class BaseInterpolator:
"""Generic interpolator class for normalizing interpolation methods"""

cons_kwargs: Dict[str, Any]
call_kwargs: Dict[str, Any]
cons_kwargs: dict[str, Any]
call_kwargs: dict[str, Any]
f: Callable
method: str

Expand Down Expand Up @@ -213,7 +221,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):


def get_clean_interp_index(
arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True
arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True
):
"""Return index to use for x values in interpolation or curve fitting.
Expand Down Expand Up @@ -300,10 +308,10 @@ def get_clean_interp_index(
def interp_na(
self,
dim: Hashable = None,
use_coordinate: Union[bool, str] = True,
use_coordinate: bool | str = True,
method: str = "linear",
limit: int = None,
max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None,
max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None,
keep_attrs: bool = None,
**kwargs,
):
Expand Down
54 changes: 52 additions & 2 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,49 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

import numpy as np
from packaging.version import Version

# Type annotations stubs
try:
from numpy.typing import ArrayLike, DTypeLike
from numpy.typing._dtype_like import _DTypeLikeNested, _ShapeLike, _SupportsDType

# Xarray requires a Mapping[Hashable, dtype] in many places which
# conflics with numpys own DTypeLike (with dtypes for fields).
# https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike
# This is a copy of this DTypeLike that allows only non-Mapping dtypes.
DTypeLikeSave = Union[
np.dtype,
# default data type (float64)
None,
# array-scalar types and generic types
Type[Any],
# character codes, type strings or comma-separated fields, e.g., 'float64'
str,
# (flexible_dtype, itemsize)
Tuple[_DTypeLikeNested, int],
# (fixed_dtype, shape)
Tuple[_DTypeLikeNested, _ShapeLike],
# (base_dtype, new_dtype)
Tuple[_DTypeLikeNested, _DTypeLikeNested],
# because numpy does the same?
List[Any],
# anything with a dtype attribute
_SupportsDType[np.dtype],
]
except ImportError:
# fall back for numpy < 1.20, ArrayLike adapted from numpy.typing._array_like
from typing import Protocol
Expand All @@ -46,8 +81,14 @@ class _SupportsArray(Protocol):
def __array__(self) -> np.ndarray:
...

class _SupportsDTypeFallback(Protocol):
@property
def dtype(self) -> np.dtype:
...

else:
_SupportsArray = Any
_SupportsDTypeFallback = Any

_T = TypeVar("_T")
_NestedSequence = Union[
Expand All @@ -72,7 +113,16 @@ def __array__(self) -> np.ndarray:
# with the same name (ArrayLike and DTypeLike from the try block)
ArrayLike = _ArrayLikeFallback # type: ignore
# fall back for numpy < 1.20
DTypeLike = Union[np.dtype, str] # type: ignore[misc]
DTypeLikeSave = Union[ # type: ignore[misc]
np.dtype,
str,
None,
Type[Any],
Tuple[Any, Any],
List[Any],
_SupportsDTypeFallback,
]
DTypeLike = DTypeLikeSave # type: ignore[misc]


if Version(np.__version__) >= Version("1.20.0"):
Expand Down
Loading

0 comments on commit e712270

Please sign in to comment.