Skip to content

Commit

Permalink
Fix mypy issues & reenable in tests (#6581)
Browse files Browse the repository at this point in the history
* Run mypy tests (but always pass)

So we can at least see the result

* Fix mypy
  • Loading branch information
max-sixty authored May 8, 2022
1 parent 6fbeb13 commit c60f9b0
Show file tree
Hide file tree
Showing 14 changed files with 38 additions and 32 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci-additional.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,9 @@ jobs:
run: |
python -m pip install mypy
# Temporarily overriding to be true due to https://github.com/pydata/xarray/issues/6551
# python -m mypy --install-types --non-interactive
- name: Run mypy
run: |
python -m mypy --install-types --non-interactive || true
python -m mypy --install-types --non-interactive
min-version-policy:
name: Minimum Version Policy
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore


DATAARRAY_NAME = "__xarray_dataarray_name__"
Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from dask.utils import SerializableLock
except ImportError:
# no need to worry about serializing the lock
SerializableLock = threading.Lock
SerializableLock = threading.Lock # type: ignore

try:
from dask.distributed import Lock as DistributedLock
except ImportError:
DistributedLock = None
DistributedLock = None # type: ignore


# Locks used by multiple backends.
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/_typed_ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ from .variable import Variable
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore

# DatasetOpsMixin etc. are parent classes of Dataset etc.
# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
Expand Down
27 changes: 14 additions & 13 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
Iterable,
Mapping,
Sequence,
overload,
)

import numpy as np
Expand Down Expand Up @@ -1846,24 +1845,26 @@ def where(cond, x, y, keep_attrs=None):
)


@overload
def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
...
# These overloads seem not to work — mypy says it can't find a matching overload for
# `DataArray` & `DataArray`, despite that being in the first overload. Would be nice to
# have overloaded functions rather than just `T_Xarray` for everything.

# @overload
# def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray:
# ...

@overload
def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
...

# @overload
# def polyval(coord: T_Xarray, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
# ...

@overload
def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
...

# @overload
# def polyval(coord: Dataset, coeffs: T_Xarray, degree_dim: Hashable) -> Dataset:
# ...


def polyval(
coord: T_Xarray, coeffs: T_Xarray, degree_dim: Hashable = "degree"
) -> T_Xarray:
def polyval(coord: T_Xarray, coeffs: T_Xarray, degree_dim="degree") -> T_Xarray:
"""Evaluate a polynomial at specific values
Parameters
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
try:
import dask.array as da
except ImportError:
da = None
da = None # type: ignore


def _validate_pad_output_shape(input_shape, pad_width, output_shape):
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import pandas as pd

from ..backends.common import AbstractDataStore, ArrayWriter
from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex
from ..plot.plot import _PlotMethods
Expand Down Expand Up @@ -67,7 +68,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore
try:
from cdms2 import Variable as cdms2_Variable
except ImportError:
Expand Down Expand Up @@ -2875,7 +2876,9 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
isnull = pd.isnull(values)
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

def to_netcdf(self, *args, **kwargs) -> bytes | Delayed | None:
def to_netcdf(
self, *args, **kwargs
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""Write DataArray contents to a netCDF file.
All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`.
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import xarray as xr

from ..backends.common import ArrayWriter
from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
Expand Down Expand Up @@ -110,7 +111,7 @@
try:
from dask.delayed import Delayed
except ImportError:
Delayed = None
Delayed = None # type: ignore


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
Expand Down Expand Up @@ -1686,7 +1687,7 @@ def to_netcdf(
unlimited_dims: Iterable[Hashable] = None,
compute: bool = True,
invalid_netcdf: bool = False,
) -> bytes | Delayed | None:
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""Write dataset contents to a netCDF file.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import dask.array as dask_array
from dask.base import tokenize
except ImportError:
dask_array = None
dask_array = None # type: ignore


def _dask_or_eager_func(
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from . import dask_array_compat
except ImportError:
dask_array = None
dask_array = None # type: ignore[assignment]
dask_array_compat = None # type: ignore[assignment]


Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore


T_Dataset = TypeVar("T_Dataset", bound="Dataset")
Expand Down
7 changes: 5 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import functools
import operator
import pickle
Expand All @@ -22,6 +24,7 @@
unified_dim_sizes,
)
from xarray.core.pycompat import dask_version
from xarray.core.types import T_Xarray

from . import has_dask, raise_if_dask_computes, requires_dask

Expand Down Expand Up @@ -2009,14 +2012,14 @@ def test_where_attrs() -> None:
),
],
)
def test_polyval(use_dask, x, coeffs, expected) -> None:
def test_polyval(use_dask, x: T_Xarray, coeffs: T_Xarray, expected) -> None:
if use_dask:
if not has_dask:
pytest.skip("requires dask")
coeffs = coeffs.chunk({"degree": 2})
x = x.chunk({"x": 2})
with raise_if_dask_computes():
actual = xr.polyval(x, coeffs)
actual = xr.polyval(coord=x, coeffs=coeffs)
xr.testing.assert_allclose(actual, expected)


Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
try:
from dask.array import from_array as dask_from_array
except ImportError:
dask_from_array = lambda x: x
dask_from_array = lambda x: x # type: ignore

try:
import pint
Expand Down
2 changes: 1 addition & 1 deletion xarray/util/generate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def inplace():
try:
from dask.array import Array as DaskArray
except ImportError:
DaskArray = np.ndarray
DaskArray = np.ndarray # type: ignore
# DatasetOpsMixin etc. are parent classes of Dataset etc.
# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally
Expand Down

0 comments on commit c60f9b0

Please sign in to comment.