Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate TypeVars in a single place #5569

Merged
merged 14 commits on Aug 21, 2021
21 changes: 10 additions & 11 deletions xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,23 @@ from .dataarray import DataArray
from .dataset import Dataset
from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
from .npcompat import ArrayLike
from .types import (
DaCompatible,
DsCompatible,
GroupByIncompatible,
ScalarOrArray,
T_DataArray,
T_Dataset,
T_Variable,
VarCompatible,
)
from .variable import Variable

try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray

# DatasetOpsMixin etc. are parent classes of Dataset etc.
T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")

ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
VarCompatible = Union[Variable, ScalarOrArray]
GroupByIncompatible = Union[Variable, GroupBy]

class DatasetOpsMixin:
__slots__ = ()
def _binary_op(self, other, f, reflexive=...): ...
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from contextlib import suppress
from html import escape
Expand Down Expand Up @@ -36,10 +38,10 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import T_DataWithCoords, T_Xarray
from .variable import Variable
from .weighted import Weighted

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

C = TypeVar("C")
T = TypeVar("T")
Expand Down Expand Up @@ -795,9 +797,7 @@ def groupby_bins(
},
)

def weighted(
self: T_DataWithCoords, weights: "DataArray"
) -> "Weighted[T_DataWithCoords]":
def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]:
"""
Weighted operations.

Expand Down
15 changes: 6 additions & 9 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -36,11 +35,9 @@
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .coordinates import Coordinates
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
from .types import T_Xarray

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
Expand Down Expand Up @@ -199,7 +196,7 @@ def result_name(objects: list) -> Any:
return name


def _get_coords_list(args) -> List["Coordinates"]:
def _get_coords_list(args) -> List[Coordinates]:
coords_list = []
for arg in args:
try:
Expand Down Expand Up @@ -400,8 +397,8 @@ def apply_dict_of_variables_vfunc(


def _fast_dataset(
variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable]
) -> "Dataset":
variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
) -> Dataset:
"""Create a dataset as quickly as possible.

Beware: the `variables` dict is modified INPLACE.
Expand Down Expand Up @@ -1729,7 +1726,7 @@ def _calc_idxminmax(
return res


def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]:
"""
Given any number of Dataset and/or DataArray objects, returns
new objects with unified chunk size along all chunked dimensions.
Expand Down
11 changes: 6 additions & 5 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import datetime
import warnings
from typing import (
Expand All @@ -12,7 +14,6 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
cast,
)
Expand Down Expand Up @@ -70,8 +71,6 @@
assert_unique_multiindex_level_names,
)

T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
if TYPE_CHECKING:
try:
from dask.delayed import Delayed
Expand All @@ -86,6 +85,8 @@
except ImportError:
iris_Cube = None

from .types import T_DataArray, T_Xarray


def _infer_coords_and_dims(
shape, coords, dims
Expand Down Expand Up @@ -3698,11 +3699,11 @@ def unify_chunks(self) -> "DataArray":

def map_blocks(
self,
func: Callable[..., T_DSorDA],
func: Callable[..., T_Xarray],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union["DataArray", "Dataset"] = None,
) -> T_DSorDA:
) -> T_Xarray:
"""
Apply a function to each block of this DataArray.

Expand Down
8 changes: 3 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
Expand Down Expand Up @@ -109,8 +108,7 @@
from ..backends import AbstractDataStore, ZarrStore
from .dataarray import DataArray
from .merge import CoercibleMapping

T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset")
from .types import T_Xarray

try:
from dask.delayed import Delayed
Expand Down Expand Up @@ -6630,11 +6628,11 @@ def unify_chunks(self) -> "Dataset":

def map_blocks(
self,
func: "Callable[..., T_DSorDA]",
func: "Callable[..., T_Xarray]",
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union["DataArray", "Dataset"] = None,
) -> "T_DSorDA":
) -> "T_Xarray":
"""
Apply a function to each block of this Dataset.

Expand Down
15 changes: 9 additions & 6 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import collections
import itertools
import operator
from typing import (
TYPE_CHECKING,
Any,
Callable,
DefaultDict,
Expand All @@ -12,7 +15,6 @@
Mapping,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -32,7 +34,8 @@
pass


T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
if TYPE_CHECKING:
from .types import T_Xarray


def unzip(iterable):
Expand Down Expand Up @@ -122,8 +125,8 @@ def make_meta(obj):


def infer_template(
func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs
) -> T_DSorDA:
func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs
) -> T_Xarray:
"""Infer return object by running the function on meta objects."""
meta_args = [make_meta(arg) for arg in (obj,) + args]

Expand Down Expand Up @@ -162,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping


def map_blocks(
func: Callable[..., T_DSorDA],
func: Callable[..., T_Xarray],
obj: Union[DataArray, Dataset],
args: Sequence[Any] = (),
kwargs: Mapping[str, Any] = None,
template: Union[DataArray, Dataset] = None,
) -> T_DSorDA:
) -> T_Xarray:
"""Apply a function to each block of a DataArray or Dataset.

.. warning::
Expand Down
21 changes: 9 additions & 12 deletions xarray/core/rolling_exp.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations

from distutils.version import LooseVersion
from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union
from typing import Generic, Hashable, Mapping, Union

import numpy as np

from .options import _get_keep_attrs
from .pdcompat import count_not_none
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .dataarray import DataArray # noqa: F401
from .dataset import Dataset # noqa: F401

T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")
from .types import T_Xarray


def _get_alpha(com=None, span=None, halflife=None, alpha=None):
Expand Down Expand Up @@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
return float(comass)


class RollingExp(Generic[T_DSorDA]):
class RollingExp(Generic[T_Xarray]):
"""
Exponentially-weighted moving window object.
Similar to EWM in pandas
Expand All @@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]):

def __init__(
self,
obj: T_DSorDA,
obj: T_Xarray,
windows: Mapping[Hashable, Union[int, float]],
window_type: str = "span",
):
self.obj: T_DSorDA = obj
self.obj: T_Xarray = obj
dim, window = next(iter(windows.items()))
self.dim = dim
self.alpha = _get_alpha(**{window_type: window})

def mean(self, keep_attrs: bool = None) -> T_DSorDA:
def mean(self, keep_attrs: bool = None) -> T_Xarray:
"""
Exponentially weighted moving average.

Expand All @@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA:
move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
)

def sum(self, keep_attrs: bool = None) -> T_DSorDA:
def sum(self, keep_attrs: bool = None) -> T_Xarray:
"""
Exponentially weighted moving sum.

Expand Down
31 changes: 31 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Centralised typing helpers (TypeVars and type aliases) for xarray core.

Everything imported under ``TYPE_CHECKING`` is needed only by static type
checkers, so importing this module at runtime stays cheap and avoids
circular imports between the core modules that themselves import it.
"""
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Union

import numpy as np

if TYPE_CHECKING:
    # Annotation-only imports: executing them at runtime would create
    # circular imports among the core modules (which import this module).
    from .common import DataWithCoords
    from .dataarray import DataArray
    from .dataset import Dataset
    from .groupby import DataArrayGroupBy, GroupBy
    from .npcompat import ArrayLike
    from .variable import Variable

    try:
        from dask.array import Array as DaskArray
    except ImportError:
        # Fallback so the string "DaskArray" in the aliases below still
        # resolves for type checkers when dask is not installed.
        DaskArray = np.ndarray

# TypeVars bound to (or, for T_Xarray, constrained to) the core classes.
T_Dataset = TypeVar("T_Dataset", bound="Dataset")
T_DataArray = TypeVar("T_DataArray", bound="DataArray")
T_Variable = TypeVar("T_Variable", bound="Variable")
# Maybe we rename this to T_Data or something less Fortran-y?
T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

# Operand-type unions used by the binary-op mixins.  Class names appear as
# strings (forward references) because the real classes are only imported
# under the TYPE_CHECKING guard above.
ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
VarCompatible = Union["Variable", "ScalarOrArray"]
GroupByIncompatible = Union["Variable", "GroupBy"]
Loading