From 3b2492a8433df14ab6129160a7a61d1c40208fec Mon Sep 17 00:00:00 2001 From: AS Date: Sat, 27 Nov 2021 15:55:29 +0100 Subject: [PATCH 01/25] initial --- .pre-commit-config.yaml | 18 +++++ climpred/classes.py | 101 +++++++++++++++++++------- climpred/versioning/print_versions.py | 5 +- docs/source/contributing.rst | 3 + setup.cfg | 40 ++++++++++ 5 files changed, 135 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4991e3525..438f2155c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,3 +28,21 @@ repos: rev: v5.6.4 hooks: - id: isort + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.910-1 + hooks: + - id: mypy + # `asv_bench` are copied from setup.cfg. + # `_typed_ops.py` is added since otherwise mypy will complain (but notably only in pre-commit) + exclude: "asv_bench" + additional_dependencies: [ + # Type stubs + types-python-dateutil, + types-pkg_resources, + types-PyYAML, + types-pytz, + typing-extensions==3.10.0.0, + # Dependencies that are typed + numpy, + ] diff --git a/climpred/classes.py b/climpred/classes.py index 0b9499d13..5912c3795 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -1,11 +1,32 @@ import warnings from copy import deepcopy +from typing import ( # TYPE_CHECKING,; Collection,; DefaultDict,; MutableMapping, + Any, + Callable, + Dict, + Hashable, + Iterable, + Iterator, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) import cf_xarray import numpy as np import xarray as xr from dask import is_dask_collection from IPython.display import display_html +from mypy_extensions import ( # (Arg, DefaultArg, NamedArg, DefaultNamedArg, + KwArg, + VarArg, +) +from xarray.core.coordinates import DatasetCoordinates +from xarray.core.dataset import DataVariables from xarray.core.formatting_html import dataset_repr from xarray.core.options import OPTIONS as XR_OPTIONS from xarray.core.utils import Frozen @@ -65,7 +86,7 @@ ) -def 
_display_metadata(self): +def _display_metadata(self) -> str: """ This is called in the following case: @@ -117,7 +138,7 @@ def _display_metadata(self): return summary -def _display_metadata_html(self): +def _display_metadata_html(self) -> str: header = f"

climpred.{type(self).__name__}

" display_html(header, raw=True) init_repr_str = dataset_repr(self._datasets["initialized"]) @@ -155,7 +176,7 @@ class PredictionEnsemble: """ @is_xarray(1) - def __init__(self, xobj): + def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]): if isinstance(xobj, xr.DataArray): # makes applying prediction functions easier, etc. xobj = xobj.to_dataset() @@ -185,7 +206,7 @@ def __init__(self, xobj): self._warn_if_chunked_along_init_member_time() @property - def coords(self): + def coords(self) -> DatasetCoordinates: """Dictionary of xarray.DataArray objects corresponding to coordinate variables available in all PredictionEnsemble._datasets. """ @@ -207,7 +228,7 @@ def nbytes(self) -> int: ) @property - def sizes(self): + def sizes(self) -> Mapping[Hashable, int]: """Mapping from dimension names to lengths for all PredictionEnsemble._datasets.""" pe_dims = dict(self.get_initialized().dims) for ds in self._datasets.values(): @@ -216,12 +237,12 @@ def sizes(self): return pe_dims @property - def dims(self): + def dims(self) -> Mapping[Hashable, int]: """Mapping from dimension names to lengths all PredictionEnsemble._datasets.""" return Frozen(self.sizes) @property - def chunks(self): + def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: """Mapping from chunks all PredictionEnsemble._datasets.""" pe_chunks = dict(self.get_initialized().chunks) for ds in self._datasets.values(): @@ -232,7 +253,16 @@ def chunks(self): return Frozen(pe_chunks) @property - def data_vars(self): + def chunksizes(self) -> Mapping[Hashable, Tuple[int, ...]]: + """Mapping from dimension names to block lengths for this dataset's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + Same as Dataset.chunks. 
+ """ + return self.chunks + + @property + def data_vars(self) -> DataVariables: """Dictionary of DataArray objects corresponding to data variables available in all PredictionEnsemble._datasets.""" varset = set(self.get_initialized().data_vars) for ds in self._datasets.values(): @@ -244,21 +274,21 @@ def data_vars(self): # when you just print it interactively # https://stackoverflow.com/questions/1535327/how-to-print-objects-of-class-using-print - def __repr__(self): + def __repr__(self) -> str: if XR_OPTIONS["display_style"] == "html": return _display_metadata_html(self) else: return _display_metadata(self) - def __len__(self): + def __len__(self) -> int: """Number of all variables in all PredictionEnsemble._datasets.""" return len(self.data_vars) - def __iter__(self): + def __iter__(self) -> Iterator[Hashable]: """Iterate over underlying xr.Datasets for initialized, uninitialized, observations.""" return iter(self._datasets.values()) - def __delitem__(self, key): + def __delitem__(self, key: Hashable) -> None: """Remove a variable from this PredictionEnsemble.""" del self._datasets["initialized"][key] for ds in self._datasets.values(): @@ -266,7 +296,7 @@ def __delitem__(self, key): if key in ds.data_vars: del ds[key] - def __contains__(self, key): + def __contains__(self, key: Hashable) -> bool: """The 'in' operator will return true or false depending on whether 'key' is an array in all PredictionEnsemble._datasets or not. """ @@ -277,7 +307,7 @@ def __contains__(self, key): contained = False return contained - def equals(self, other): + def equals(self, other: Union["PredictionEnsemble", Any]) -> bool: """Two PredictionEnsembles are equal if they have matching variables and coordinates, all of which are equal. 
PredictionEnsembles can still be equal (like pandas objects) if they have NaN @@ -298,7 +328,7 @@ def equals(self, other): return False return equal - def identical(self, other): + def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates.""" if not isinstance(other, PredictionEnsemble): @@ -362,7 +392,11 @@ def plot(self, variable=None, ax=None, show_members=False, cmap=None, x="time"): self, variable=variable, ax=ax, show_members=show_members, cmap=cmap ) - def _math(self, other, operator): + def _math( + self, + other, + operator, + ): """Helper function for __add__, __sub__, __mul__, __truediv__. Allows math operations with type: @@ -411,7 +445,7 @@ def div(a, b): ) # catch other dimensions in other if isinstance(other, tuple([xr.Dataset, xr.DataArray])): - if not set(other.dims).issubset(self._datasets["initialized"].dims): + if not set(other.dims).issubset(self._datasets["initialized"].dims): # type: ignore raise DimensionError(f"{error_str} containing new dimensions.") # catch xr.Dataset with different data_vars if isinstance(other, xr.Dataset): @@ -434,7 +468,7 @@ def div(a, b): if isinstance(other._datasets[dataset], xr.Dataset) and isinstance( self._datasets[dataset], xr.Dataset ): - datasets[dataset] = operator( + datasets[dataset] = operator( # type: ignore datasets[dataset], other._datasets[dataset] ) return self._construct_direct(datasets, kind=self.kind) @@ -453,7 +487,7 @@ def __mul__(self, other): def __truediv__(self, other): return self._math(other, operator="div") - def __getitem__(self, varlist): + def __getitem__(self, varlist: Union[str, List[str]]) -> "PredictionEnsemble": """Allows subsetting data variable from PredictionEnsemble as from xr.Dataset. 
Args: @@ -474,7 +508,7 @@ def sel_vars(ds, varlist): return self._apply_func(sel_vars, varlist) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Callable[[VarArg(Any), KwArg(Any)], Any]: """Allows for xarray methods to be applied to our prediction objects. Args: @@ -562,7 +596,9 @@ def _construct_direct(cls, datasets, kind): obj._warn_if_chunked_along_init_member_time() return obj - def _apply_func(self, func, *args, **kwargs): + def _apply_func( + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> "PredictionEnsemble": """Apply a function to all datasets in a `PredictionEnsemble`.""" # Create temporary copy to modify to avoid inplace operation. datasets = self._datasets.copy() @@ -591,15 +627,20 @@ def _apply_func(self, func, *args, **kwargs): # Instantiates new object with the modified datasets. return self._construct_direct(datasets, kind=self.kind) - def get_initialized(self): + def get_initialized(self) -> xr.Dataset: """Returns the xarray dataset for the initialized ensemble.""" return self._datasets["initialized"] - def get_uninitialized(self): + def get_uninitialized(self) -> xr.Dataset: """Returns the xarray dataset for the uninitialized ensemble.""" return self._datasets["uninitialized"] - def smooth(self, smooth_kws=None, how="mean", **xesmf_kwargs): + def smooth( + self, + smooth_kws=None, + how="mean", + **xesmf_kwargs, + ): """Smooth all entries of PredictionEnsemble in the same manner to be able to still calculate prediction skill afterwards. @@ -724,7 +765,9 @@ def smooth(self, smooth_kws=None, how="mean", **xesmf_kwargs): ) return self - def remove_seasonality(self, initialized_dim="init", seasonality=None): + def remove_seasonality( + self, initialized_dim: str = "init", seasonality: Union[None, str] = None + ) -> "PredictionEnsemble": """Remove seasonal cycle from all climpred datasets. Args: @@ -811,7 +854,7 @@ class PerfectModelEnsemble(PredictionEnsemble): be an `xarray` Dataset or DataArray. 
""" - def __init__(self, xobj): + def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]) -> None: """Create a `PerfectModelEnsemble` object by inputting output from the control run in `xarray` format. @@ -1350,7 +1393,7 @@ class HindcastEnsemble(PredictionEnsemble): be an `xarray` Dataset or DataArray. """ - def __init__(self, xobj): + def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]) -> None: """Create a `HindcastEnsemble` object by inputting output from a prediction ensemble in `xarray` format. @@ -1411,7 +1454,9 @@ def _vars_to_drop(self, init=True): return init_vars_to_drop, obs_vars_to_drop @is_xarray(1) - def add_observations(self, xobj): + def add_observations( + self, xobj: Union[xr.DataArray, xr.Dataset] + ) -> "HindcastEnsemble": """Add verification data against which to verify the initialized ensemble. Args: diff --git a/climpred/versioning/print_versions.py b/climpred/versioning/print_versions.py index ca55130d3..c0f615601 100644 --- a/climpred/versioning/print_versions.py +++ b/climpred/versioning/print_versions.py @@ -106,10 +106,7 @@ def show_versions(as_json=False): deps_blob.append((modname, "installed")) if as_json: - try: - import json - except Exception: - import simplejson as json + import json j = dict(system=dict(sys_info), dependencies=dict(deps_blob)) diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 02066a062..45109a1b4 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -148,6 +148,9 @@ Preparing Pull Requests needed, or will generally be quite clear about what you need to do to pass the commit test. + ``pre-commit`` also runs `mypy `_ for static type checking on + `type hints `_. + #. Break your edits up into reasonably sized commits:: $ git commit -a -m "" diff --git a/setup.cfg b/setup.cfg index e6b22b57b..c27bc77c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,3 +55,43 @@ markers = [aliases] test = pytest + + +[mypy] +exclude = asv_bench|doc +files = . 
+show_error_codes = True + +# Most of the numerical computing stack doesn't have type annotations yet. +[mypy-bottleneck.*] +ignore_missing_imports = True +[mypy-cftime.*] +ignore_missing_imports = True +[mypy-dask.*] +ignore_missing_imports = True +[mypy-distributed.*] +ignore_missing_imports = True +[mypy-fsspec.*] +ignore_missing_imports = True +[mypy-matplotlib.*] +ignore_missing_imports = True +[mypy-nc_time_axis.*] +ignore_missing_imports = True +[mypy-numpy.*] +ignore_missing_imports = True +[mypy-netCDF4.*] +ignore_missing_imports = True +[mypy-pandas.*] +ignore_missing_imports = True +[mypy-pytest.*] +ignore_missing_imports = True +[mypy-scipy.*] +ignore_missing_imports = True +[mypy-setuptools] +ignore_missing_imports = True +[mypy-toolz.*] +ignore_missing_imports = True +# version spanning code is hard to type annotate (and most of this module will +# be going away soon anyways) +[mypy-xarray.core.pycompat] +ignore_errors = True From d8d9e6daf77d6626424e27985202f78000a3a702 Mon Sep 17 00:00:00 2001 From: AS Date: Sun, 28 Nov 2021 12:42:14 +0100 Subject: [PATCH 02/25] verify --- climpred/classes.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index 5912c3795..c08225e42 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -926,7 +926,9 @@ def _vars_to_drop(self, init=True): return init_vars_to_drop, ctrl_vars_to_drop @is_xarray(1) - def add_control(self, xobj): + def add_control( + self, xobj: Union[xr.DataArray, xr.Dataset] + ) -> "PerfectModelEnsemble": """Add the control run that initialized the climate prediction ensemble. 
@@ -948,7 +950,7 @@ def add_control(self, xobj): datasets.update({"control": xobj}) return self._construct_direct(datasets, kind="perfect") - def generate_uninitialized(self): + def generate_uninitialized(self) -> "PerfectModelEnsemble": """Generate an uninitialized ensemble by bootstrapping the initialized prediction ensemble. @@ -967,19 +969,29 @@ def generate_uninitialized(self): datasets.update({"uninitialized": uninit}) return self._construct_direct(datasets, kind="perfect") - def get_control(self): + def get_control(self) -> xr.Dataset: """Returns the control as an xarray dataset.""" return self._datasets["control"] + from .comparison import Comparison + from .metrics import Metric + + metricType = Union[str, Metric] + comparisonType = Union[str, Comparison] + dimType = Optional[Union[Hashable, Iterable[Hashable]]] + referenceType = Optional[Union[List[str], str]] + groupbyType = Optional[Union[str, xr.DataArray]] + metric_kwargsType = Optional[Any] + def verify( self, - metric=None, - comparison=None, - dim=None, - reference=None, - groupby=None, - **metric_kwargs, - ): + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + reference: referenceType = None, + groupby: groupbyType = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Verify initialized predictions against a configuration of other ensemble members. .. 
note:: @@ -1089,7 +1101,7 @@ def verify( ref_compute_kwargs["comparison"] = comparison ref = getattr(self, f"_compute_{r}")(**ref_compute_kwargs) result = xr.concat([result, ref], dim="skill", **CONCAT_KWARGS) - result = result.assign_coords(skill=["initialized"] + reference) + result = result.assign_coords(skill=["initialized"] + reference) # type: ignore return result.squeeze() def _compute_uninitialized( From f3def91c99dd4e2f4e2ed87b54e389a31875ae5c Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 17:16:07 +0100 Subject: [PATCH 03/25] mypy bootstrap and verify --- climpred/checks.py | 5 ++- climpred/classes.py | 88 ++++++++++++++++++++++++--------------------- 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/climpred/checks.py b/climpred/checks.py index 831ba97c6..6dc8e91d7 100644 --- a/climpred/checks.py +++ b/climpred/checks.py @@ -307,7 +307,10 @@ def warn_if_chunking_would_increase_performance(ds, crit_size_in_MB=100): ) -def _check_valid_reference(reference): +from typing import List, Optional, Union + + +def _check_valid_reference(reference: Optional[Union[List[str], str]]) -> List: """Enforce reference as list and check for valid entries.""" if reference is None: reference = [] diff --git a/climpred/classes.py b/climpred/classes.py index c08225e42..5ef39f3e0 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -3,6 +3,7 @@ from typing import ( # TYPE_CHECKING,; Collection,; DefaultDict,; MutableMapping, Any, Callable, + Collection, Dict, Hashable, Iterable, @@ -52,6 +53,7 @@ match_initialized_vars, rename_to_climpred_dims, ) +from .comparison import Comparison from .constants import ( BIAS_CORRECTION_BIAS_CORRECTION_METHODS, BIAS_CORRECTION_TRAIN_TEST_SPLIT_METHODS, @@ -65,6 +67,7 @@ from .exceptions import DimensionError, VariableError from .graphics import plot_ensemble_perfect_model, plot_lead_timeseries_hindcast from .logging import log_compute_hindcast_header +from .metrics import Metric from .options import OPTIONS 
from .prediction import ( _apply_metric_at_given_lead, @@ -85,6 +88,14 @@ convert_Timedelta_to_lead_units, ) +metricType = Union[str, Metric] +comparisonType = Union[str, Comparison] +dimType = Optional[Union[str, List[str]]] +alignmentType = str +referenceType = Union[List[str], str] +groupbyType = Optional[Union[str, xr.DataArray]] +metric_kwargsType = Optional[Any] + def _display_metadata(self) -> str: """ @@ -973,16 +984,6 @@ def get_control(self) -> xr.Dataset: """Returns the control as an xarray dataset.""" return self._datasets["control"] - from .comparison import Comparison - from .metrics import Metric - - metricType = Union[str, Metric] - comparisonType = Union[str, Comparison] - dimType = Optional[Union[Hashable, Iterable[Hashable]]] - referenceType = Optional[Union[List[str], str]] - groupbyType = Optional[Union[str, xr.DataArray]] - metric_kwargsType = Optional[Any] - def verify( self, metric: metricType = None, @@ -1250,16 +1251,17 @@ def _compute_climatology( def bootstrap( self, - metric=None, - comparison=None, - dim=None, - reference=None, - iterations=None, - sig=95, - pers_sig=None, - groupby=None, - **metric_kwargs, - ): + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + reference: referenceType = None, + groupby: groupbyType = None, + iterations: Optional[int] = None, + sig: int = 95, + resample_dim: str = "member", + pers_sig: Optional[int] = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Bootstrap with replacement according to Goddard et al. 2013. Args: @@ -1278,6 +1280,7 @@ def bootstrap( If None or empty, returns no p value. iterations (int): Number of resampling iterations for bootstrapping with replacement. Recommended >= 500. + resample_dim (str): dimension for resampling sig (int, default 95): Significance level in percent for deciding whether uninitialized and persistence beat initialized skill. 
pers_sig (int): If not ``None``, the separate significance level for @@ -1358,6 +1361,7 @@ def bootstrap( comparison=comparison, dim=dim, iterations=iterations, + resample_dim=resample_dim, sig=sig, pers_sig=pers_sig, **metric_kwargs, @@ -1521,14 +1525,14 @@ def get_observations(self): def verify( self, - reference=None, - metric=None, - comparison=None, - dim=None, - alignment=None, - groupby=None, - **metric_kwargs, - ): + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + alignment: alignmentType = None, + reference: referenceType = None, + groupby: groupbyType = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Verifies the initialized ensemble against observations. .. note:: @@ -1601,6 +1605,10 @@ def verify( Data variables: SST (skill, lead) float64 0.9023 0.8807 0.8955 ... 0.9078 0.9128 0.9159 """ + if isinstance(reference, str): + reference = [reference] + else: + pass # reference = list(reference) if groupby is not None: skill_group = [] group_label = [] @@ -1761,18 +1769,18 @@ def _verify( def bootstrap( self, - metric=None, - comparison=None, - dim=None, - alignment=None, - reference=None, - iterations=None, - sig=95, - resample_dim="member", - pers_sig=None, - groupby=None, - **metric_kwargs, - ): + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + alignment: alignmentType = None, + reference: referenceType = None, + groupby: groupbyType = None, + iterations: Optional[int] = None, + sig: int = 95, + resample_dim: str = "member", + pers_sig: Optional[int] = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Bootstrap with replacement according to Goddard et al. 2013. 
Args: From a5104af38b47fa3b4af6ba27e171b6fda4cfbe2e Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 17:21:57 +0100 Subject: [PATCH 04/25] add mypy to env --- ci/requirements/climpred-dev.yml | 2 +- climpred/classes.py | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/ci/requirements/climpred-dev.yml b/ci/requirements/climpred-dev.yml index 72a0aed89..d9be284f6 100644 --- a/ci/requirements/climpred-dev.yml +++ b/ci/requirements/climpred-dev.yml @@ -30,7 +30,7 @@ dependencies: - black==19.10b0 - coveralls - flake8 - #- importlib_metadata + - mypy - isort - pre-commit - pylint diff --git a/climpred/classes.py b/climpred/classes.py index 5ef39f3e0..a6817ff63 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -1494,7 +1494,9 @@ def add_observations( return self._construct_direct(datasets, kind="hindcast") @is_xarray(1) - def add_uninitialized(self, xobj): + def add_uninitialized( + self, xobj: Union[xr.DataArray, xr.Dataset] + ) -> "HindcastEnsemble": """Add a companion uninitialized ensemble for comparison to verification data. Args: @@ -1515,7 +1517,7 @@ def add_uninitialized(self, xobj): datasets.update({"uninitialized": xobj}) return self._construct_direct(datasets, kind="hindcast") - def get_observations(self): + def get_observations(self) -> xr.Dataset: """Returns xarray Datasets of the observations/verification data. 
Returns: @@ -1936,14 +1938,14 @@ def bootstrap( def remove_bias( self, - alignment=None, - how="additive_mean", - train_test_split="unfair", - train_init=None, - train_time=None, - cv=False, - **metric_kwargs, - ): + alignment: alignmentType = None, + how: str = "additive_mean", + train_test_split: str = "unfair", + train_init: Optional[Union[xr.DataArray, slice]] = None, + train_time: Optional[Union[xr.DataArray, slice]] = None, + cv: Union[bool, str] = False, + **metric_kwargs: metric_kwargsType, + ) -> "HindcastEnsemble": """Calculate and remove bias from :py:class:`~climpred.classes.HindcastEnsemble`. Bias is grouped by ``seasonality`` set via :py:class:`~climpred.options.set_options`. When wrapping xclim.sbda.adjustment use ``group`` instead. From fc15ecc5c51f9208f6151d6222e56622488e1b1e Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 17:44:26 +0100 Subject: [PATCH 05/25] typing graphics --- climpred/classes.py | 15 +++++++++--- climpred/graphics.py | 58 ++++++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index a6817ff63..b8f6a57cf 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -355,7 +355,16 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: return False return id - def plot(self, variable=None, ax=None, show_members=False, cmap=None, x="time"): + import matplotlib.pyplot as plt + + def plot( + self, + variable: Optional[str] = None, + ax: Optional[plt.Axes] = None, + show_members: bool = False, + cmap: Optional[str] = None, + x: str = "time", + ) -> plt.Axes: """Plot datasets from PredictionEnsemble. 
Args: @@ -385,7 +394,7 @@ def plot(self, variable=None, ax=None, show_members=False, cmap=None, x="time"): if x == "time": x = "valid_time" assert x in ["valid_time", "init"] - if self.kind == "hindcast": + if isinstance(self, HindcastEnsemble): if cmap is None: cmap = "viridis" return plot_lead_timeseries_hindcast( @@ -396,7 +405,7 @@ def plot(self, variable=None, ax=None, show_members=False, cmap=None, x="time"): cmap=cmap, x=x, ) - elif self.kind == "perfect": + elif isinstance(self, PerfectModelEnsemble): if cmap is None: cmap = "tab10" return plot_ensemble_perfect_model( diff --git a/climpred/graphics.py b/climpred/graphics.py index b6605d805..18237bac3 100644 --- a/climpred/graphics.py +++ b/climpred/graphics.py @@ -1,11 +1,13 @@ import warnings from collections import OrderedDict +from typing import Optional, Tuple, Union import numpy as np import xarray as xr from xarray.coding.times import infer_calendar_name from .checks import DimensionError +from .classes import HindcastEnsemble, PerfectModelEnsemble from .constants import CLIMPRED_DIMS from .metrics import ALL_METRICS, PROBABILISTIC_METRICS from .utils import get_lead_cftime_shift_args, get_metric_class, shift_cftime_index @@ -79,17 +81,17 @@ def plot_relative_entropy(rel_ent, rel_ent_threshold=None, **kwargs): def plot_bootstrapped_skill_over_leadyear( - bootstrapped, - ax=None, - color_initialized="indianred", - color_uninitialized="steelblue", - color_persistence="gray", - color_climatology="tan", - capsize=4, - fontsize=8, - figsize=(10, 4), - fmt="--o", -): + bootstrapped: xr.Dataset, + ax: Optional[plt.Axes] = None, + color_initialized: str = "indianred", + color_uninitialized: str = "steelblue", + color_persistence: str = "gray", + color_climatology: str = "tan", + capsize: Union[int, float] = 4, + fontsize: Union[int, float] = 8, + figsize: Tuple = (10, 4), + fmt: str = "--o", +) -> plt.Axes: """ Plot Ensemble Prediction skill as in Li et al. 2016 Fig.3a-c. 
@@ -209,8 +211,13 @@ def _check_only_climpred_dims(pe): def plot_lead_timeseries_hindcast( - he, variable=None, ax=None, show_members=False, cmap="viridis", x="time" -): + he: HindcastEnsemble, + variable: Optional[str] = None, + ax: Optional[plt.Axes] = None, + show_members: bool = False, + cmap: Optional[str] = "viridis", + x: str = "time", +) -> plt.Axes: """Plot datasets from HindcastEnsemble. Args: @@ -238,13 +245,13 @@ def plot_lead_timeseries_hindcast( if isinstance(obs, xr.Dataset): obs = obs[variable] - cmap = mpl.cm.get_cmap(cmap, hind.lead.size) + _cmap = mpl.cm.get_cmap(cmap, hind.lead.size) if ax is None: _, ax = plt.subplots(figsize=(10, 4)) if isinstance(hist, xr.DataArray) and x == "valid_time": if "member" in hist.dims and not show_members: hist = hist.mean("member") - member_alpha = 1 + member_alpha = 1.0 lw = 2 else: member_alpha = 0.4 @@ -263,14 +270,14 @@ def plot_lead_timeseries_hindcast( h = hind.sel(lead=lead) if not show_members and "member" in h.dims: h = h.mean("member") - lead_alpha = 1 + lead_alpha = 1.0 else: lead_alpha = 0.5 h.plot( ax=ax, x=x, hue="member", - color=cmap(i), + color=_cmap(i), label=f"initialized: lead={lead} {hind.lead.attrs['units'][:-1]}", alpha=lead_alpha, zorder=hind.lead.size - i, @@ -299,8 +306,13 @@ def plot_lead_timeseries_hindcast( def plot_ensemble_perfect_model( - pm, variable=None, ax=None, show_members=False, cmap="tab10" -): + pm: PerfectModelEnsemble, + variable: Optional[str] = None, + ax: Optional[plt.Axes] = None, + show_members: bool = False, + cmap: Optional[str] = "tab10", + x: str = "time", +) -> plt.Axes: """Plot datasets from PerfectModelEnsemble. 
Args: @@ -335,7 +347,7 @@ def plot_ensemble_perfect_model( if ax is None: _, ax = plt.subplots(figsize=(10, 4)) - cmap = mpl.cm.get_cmap(cmap, initialized.init.size) + _cmap = mpl.cm.get_cmap(cmap, initialized.init.size) for ii, i in enumerate(initialized.init.values): dsi = initialized.sel(init=i) @@ -345,7 +357,7 @@ def plot_ensemble_perfect_model( dsi = dsi.mean("member") if uninitialized_present: dsu = dsu.mean("member") - member_alpha = 1 + member_alpha = 1.0 lw = 2 labelstr = "ensemble mean" else: @@ -362,12 +374,12 @@ def plot_ensemble_perfect_model( ) # plot ensemble mean, first white then color to highlight ensemble mean dsi.mean("member").plot(ax=ax, x=x, color="white", lw=3, zorder=10) - dsi.mean("member").plot(ax=ax, x=x, color=cmap(ii), lw=2, zorder=11) + dsi.mean("member").plot(ax=ax, x=x, color=_cmap(ii), lw=2, zorder=11) dsi.plot( ax=ax, x=x, hue="member", - color=cmap(ii), + color=_cmap(ii), alpha=member_alpha, lw=lw, label=labelstr, From e9af2f8dfae71b2363c52ac57c6483b7fcfd79fb Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 17:53:43 +0100 Subject: [PATCH 06/25] mypy metrics and comparisons --- climpred/comparisons.py | 21 ++++++++++++--------- climpred/metrics.py | 34 +++++++++++++++++++--------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/climpred/comparisons.py b/climpred/comparisons.py index 9bdfcf42d..46c7ed50b 100644 --- a/climpred/comparisons.py +++ b/climpred/comparisons.py @@ -19,7 +19,7 @@ def _transpose_and_rechunk_to(new_chunk_ds, ori_chunk_ds): ) -def _display_comparison_metadata(self): +def _display_comparison_metadata(self) -> str: summary = "----- Comparison metadata -----\n" summary += f"Name: {self.name}\n" # probabilistic or only deterministic @@ -33,18 +33,21 @@ def _display_comparison_metadata(self): return summary +from typing import Callable, List, Optional + + class Comparison: """Master class for all comparisons.""" def __init__( self, - name, - function, - hindcast, - probabilistic, - 
long_name=None, - aliases=None, - ): + name: str, + function: Callable, + hindcast: bool, + probabilistic: bool, + long_name: Optional[str] = None, + aliases: Optional[List[str]] = None, + ) -> None: """Comparison initialization. Args: @@ -72,7 +75,7 @@ def __init__( self.long_name = long_name self.aliases = aliases - def __repr__(self): + def __repr__(self) -> str: """Show metadata of comparison class.""" return _display_comparison_metadata(self) diff --git a/climpred/metrics.py b/climpred/metrics.py index 190a7896c..a0f0cee10 100644 --- a/climpred/metrics.py +++ b/climpred/metrics.py @@ -163,7 +163,7 @@ def _maybe_member_mean_reduce_dim(forecast, dim): return forecast, dim -def _display_metric_metadata(self): +def _display_metric_metadata(self) -> str: summary = "----- Metric metadata -----\n" summary += f"Name: {self.name}\n" summary += f"Alias: {self.aliases}\n" @@ -189,32 +189,36 @@ def _display_metric_metadata(self): return summary +from typing import Callable, List, Optional + + class Metric: """Master class for all metrics.""" def __init__( self, - name, - function, - positive, - probabilistic, - unit_power, - long_name=None, - aliases=None, - minimum=None, - maximum=None, - perfect=None, - normalize=False, - allows_logical=False, - requires_member_dim=False, + name: str, + function: Callable, + positive: Optional[bool], + probabilistic: bool, + unit_power: float, + long_name: Optional[str] = None, + aliases: Optional[List[str]] = None, + minimum: Optional[float] = None, + maximum: Optional[float] = None, + perfect: Optional[float] = None, + normalize: bool = False, + allows_logical: bool = False, + requires_member_dim: bool = False, ): """Metric initialization. Args: name (str): name of metric. function (function): metric function. - positive (bool): Is metric positively oriented? If True, higher metric + positive (bool or None): Is metric positively oriented? If True, higher metric value means better skill. If False, lower metric value means better skill. 
+ None if different differentiation. probabilistic (bool): Is metric probabilistic? `False` means deterministic. unit_power (float, int): Power of the unit of skill based on unit From 447c0ca4201b77633323e00e2c22e9ce4f18fed2 Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 18:00:23 +0100 Subject: [PATCH 07/25] mypy math --- climpred/classes.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index b8f6a57cf..5294b5684 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -412,10 +412,12 @@ def plot( self, variable=variable, ax=ax, show_members=show_members, cmap=cmap ) + mathType = Union[int, float, np.ndarray, xr.DataArray, xr.Dataset] + def _math( self, - other, - operator, + other: mathType, + operator: str, ): """Helper function for __add__, __sub__, __mul__, __truediv__. @@ -478,7 +480,7 @@ def div(a, b): f"other.data_vars = {list(other.data_vars)}." ) - operator = eval(operator) + _operator = eval(operator) if isinstance(other, PredictionEnsemble): # Create temporary copy to modify to avoid inplace operation. 
@@ -488,23 +490,23 @@ def div(a, b): if isinstance(other._datasets[dataset], xr.Dataset) and isinstance( self._datasets[dataset], xr.Dataset ): - datasets[dataset] = operator( # type: ignore + datasets[dataset] = _operator( datasets[dataset], other._datasets[dataset] ) return self._construct_direct(datasets, kind=self.kind) else: - return self._apply_func(operator, other) + return self._apply_func(_operator, other) - def __add__(self, other): + def __add__(self, other: mathType) -> "PredictionEnsemble": return self._math(other, operator="add") - def __sub__(self, other): + def __sub__(self, other: mathType) -> "PredictionEnsemble": return self._math(other, operator="sub") - def __mul__(self, other): + def __mul__(self, other: mathType) -> "PredictionEnsemble": return self._math(other, operator="mul") - def __truediv__(self, other): + def __truediv__(self, other: mathType) -> "PredictionEnsemble": return self._math(other, operator="div") def __getitem__(self, varlist: Union[str, List[str]]) -> "PredictionEnsemble": From 834219642c25a86cf00671f95d9c7948f7c1312d Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 18:06:44 +0100 Subject: [PATCH 08/25] fix comparisons --- climpred/comparisons.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/climpred/comparisons.py b/climpred/comparisons.py index 46c7ed50b..3d5200c8b 100644 --- a/climpred/comparisons.py +++ b/climpred/comparisons.py @@ -1,9 +1,12 @@ +from typing import Callable, List, Optional + import dask import numpy as np import xarray as xr from .checks import has_dims, has_min_len from .constants import M2M_MEMBER_DIM +from .metrics import Metric def _transpose_and_rechunk_to(new_chunk_ds, ori_chunk_ds): @@ -33,16 +36,13 @@ def _display_comparison_metadata(self) -> str: return summary -from typing import Callable, List, Optional - - class Comparison: """Master class for all comparisons.""" def __init__( self, name: str, - function: Callable, + function: Callable[..., 
Optional[Metric]], hindcast: bool, probabilistic: bool, long_name: Optional[str] = None, From 76e789b8ae46205dfe7c9fc0cd4fd9379ca2e27b Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 19:06:34 +0100 Subject: [PATCH 09/25] fix comparison typing --- climpred/comparisons.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/climpred/comparisons.py b/climpred/comparisons.py index 3d5200c8b..5616344c8 100644 --- a/climpred/comparisons.py +++ b/climpred/comparisons.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Tuple import dask import numpy as np @@ -42,7 +42,9 @@ class Comparison: def __init__( self, name: str, - function: Callable[..., Optional[Metric]], + function: Callable[ + [xr.Dataset, Optional[Metric]], Tuple[xr.Dataset, xr.Dataset] + ], hindcast: bool, probabilistic: bool, long_name: Optional[str] = None, From d8f5a29d32446b714caf790ad1f3acfba8eede3a Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 20:27:06 +0100 Subject: [PATCH 10/25] classes all but smooth --- climpred/classes.py | 44 +++++++++++++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index 5294b5684..edc9845ce 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -623,6 +623,7 @@ def _apply_func( ) -> "PredictionEnsemble": """Apply a function to all datasets in a `PredictionEnsemble`.""" # Create temporary copy to modify to avoid inplace operation. + # isn't that essentially the same as .map(func)? datasets = self._datasets.copy() # More explicit than nested dictionary comprehension.
@@ -831,7 +832,7 @@ def _remove_seasonality(ds, initialized_dim="init", seasonality=None): seasonality=seasonality, ) - def _warn_if_chunked_along_init_member_time(self): + def _warn_if_chunked_along_init_member_time(self) -> None: """Warn upon instantiation when CLIMPRED_DIMS except ``lead`` are chunked with more than one chunk to show how to circumvent ``xskillscore`` chunking ``ValueError``.""" @@ -896,7 +897,12 @@ def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]) -> None: self._datasets.update({"control": {}}) self.kind = "perfect" - def _apply_climpred_function(self, func, input_dict=None, **kwargs): + def _apply_climpred_function( + self, + func: Callable[..., Any], + input_dict: Dict[str, Any], + **kwargs: Any, + ): """Helper function to loop through observations and apply an arbitrary climpred function. @@ -916,7 +922,7 @@ def _apply_climpred_function(self, func, input_dict=None, **kwargs): control = control.drop_vars(ctrl_vars) return func(ensemble, control, **kwargs) - def _vars_to_drop(self, init=True): + def _vars_to_drop(self, init: bool = True) -> Tuple[List[str], List[str]]: """Returns list of variables to drop when comparing initialized/uninitialized to a control. @@ -1113,12 +1119,16 @@ def verify( ref_compute_kwargs["comparison"] = comparison ref = getattr(self, f"_compute_{r}")(**ref_compute_kwargs) result = xr.concat([result, ref], dim="skill", **CONCAT_KWARGS) - result = result.assign_coords(skill=["initialized"] + reference) # type: ignore + result = result.assign_coords(skill=["initialized"] + reference) return result.squeeze() def _compute_uninitialized( - self, metric=None, comparison=None, dim=None, **metric_kwargs - ): + self, + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Verify the bootstrapped uninitialized run against itself. .. 
note:: @@ -1169,7 +1179,12 @@ def _compute_uninitialized( res["lead"].attrs = self.get_initialized().lead.attrs return res - def _compute_persistence(self, metric=None, dim=None, **metric_kwargs): + def _compute_persistence( + self, + metric: metricType = None, + dim: dimType = None, + **metric_kwargs: metric_kwargsType, + ): """Verify a simple persistence forecast of the control run against itself. Args: @@ -1215,8 +1230,12 @@ def _compute_persistence(self, metric=None, dim=None, **metric_kwargs): return res def _compute_climatology( - self, metric=None, comparison=None, dim=None, **metric_kwargs - ): + self, + metric: metricType = None, + comparison: comparisonType = None, + dim: dimType = None, + **metric_kwargs: metric_kwargsType, + ) -> xr.Dataset: """Verify a climatology forecast of the control run against itself. Args: @@ -1438,7 +1457,9 @@ def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]) -> None: self._datasets.update({"observations": {}}) self.kind = "hindcast" - def _apply_climpred_function(self, func, init, **kwargs): + def _apply_climpred_function( + self, func: Callable[..., Any], init: bool, **kwargs: Any + ) -> "HindcastEnsemble": """Helper function to loop through verification data and apply an arbitrary climpred function. @@ -1446,12 +1467,13 @@ def _apply_climpred_function(self, func, init, **kwargs): func (function): climpred function to apply to object. init (bool): Whether or not it's the initialized ensemble. """ + # fixme: essentially the same as map? hind = self._datasets["initialized"] verif = self._datasets["observations"] drop_init, drop_obs = self._vars_to_drop(init=init) return func(hind.drop_vars(drop_init), verif.drop_vars(drop_obs), **kwargs) - def _vars_to_drop(self, init=True): + def _vars_to_drop(self, init: bool = True) -> Tuple[List[str], List[str]]: """Returns list of variables to drop when comparing initialized/uninitialized to observations. 
From 4b440a92c8bb14f822b77ecbbc163f12541efeb5 Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 20:30:07 +0100 Subject: [PATCH 11/25] add mypy_ext to doc env --- ci/requirements/docs.yml | 1 + climpred/classes.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ci/requirements/docs.yml b/ci/requirements/docs.yml index a67e573ab..493830743 100644 --- a/ci/requirements/docs.yml +++ b/ci/requirements/docs.yml @@ -7,6 +7,7 @@ dependencies: - importlib_metadata - matplotlib-base - nbsphinx + - mypy_extensions - nc-time-axis - netcdf4 - sphinx diff --git a/climpred/classes.py b/climpred/classes.py index edc9845ce..dae72f080 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -53,7 +53,7 @@ match_initialized_vars, rename_to_climpred_dims, ) -from .comparison import Comparison +from .comparisons import Comparison from .constants import ( BIAS_CORRECTION_BIAS_CORRECTION_METHODS, BIAS_CORRECTION_TRAIN_TEST_SPLIT_METHODS, From a03c1ef602702eb918b2b9c2b2ee9b8fa7124786 Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 20:31:14 +0100 Subject: [PATCH 12/25] add mypy_ext to docs_notebook --- ci/requirements/docs_notebooks.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/requirements/docs_notebooks.yml b/ci/requirements/docs_notebooks.yml index 3c4331dd4..471957643 100644 --- a/ci/requirements/docs_notebooks.yml +++ b/ci/requirements/docs_notebooks.yml @@ -18,6 +18,7 @@ dependencies: - sphinxcontrib-napoleon - sphinx_rtd_theme - sphinx-copybutton + - mypy_extensions - toolz - xarray>=0.16.1 - esmtools>=1.1.3 From 4eaded3b2484f7302e77d0c773aa5d3fce6d53be Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 20:34:19 +0100 Subject: [PATCH 13/25] add mypy to all envs --- ci/requirements/climpred-dev.yml | 1 + ci/requirements/maximum-tests.yml | 1 + ci/requirements/minimum-tests.yml | 1 + climpred/classes.py | 5 +---- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ci/requirements/climpred-dev.yml
b/ci/requirements/climpred-dev.yml index d9be284f6..867a4ccec 100644 --- a/ci/requirements/climpred-dev.yml +++ b/ci/requirements/climpred-dev.yml @@ -37,6 +37,7 @@ dependencies: - pytest - pytest-cov - pytest-sugar + - mypy_extensions # Performance - bottleneck - numba diff --git a/ci/requirements/maximum-tests.yml b/ci/requirements/maximum-tests.yml index 609f17beb..c98e0bdd5 100644 --- a/ci/requirements/maximum-tests.yml +++ b/ci/requirements/maximum-tests.yml @@ -17,6 +17,7 @@ dependencies: - pytest - pytest-cov - pytest-xdist + - mypy_extensions - scipy - xarray>=0.16.1 - xesmf diff --git a/ci/requirements/minimum-tests.yml b/ci/requirements/minimum-tests.yml index bf479d0ea..994cc6c57 100644 --- a/ci/requirements/minimum-tests.yml +++ b/ci/requirements/minimum-tests.yml @@ -16,6 +16,7 @@ dependencies: - scipy - xarray>=0.16.1 - xskillscore>=0.0.18 + - mypy_extensions - pip: - pytest-lazy-fixture - -e ../.. diff --git a/climpred/classes.py b/climpred/classes.py index dae72f080..904bc4f63 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -22,10 +22,7 @@ import xarray as xr from dask import is_dask_collection from IPython.display import display_html -from mypy_extensions import ( # (Arg, DefaultArg, NamedArg, DefaultNamedArg, - KwArg, - VarArg, -) +from mypy_extensions import KwArg, VarArg from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataset import DataVariables from xarray.core.formatting_html import dataset_repr From f622f211b75a0f37875b73b511779c8f4177ceef Mon Sep 17 00:00:00 2001 From: AS Date: Mon, 29 Nov 2021 20:35:30 +0100 Subject: [PATCH 14/25] add mypy_extensions to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 757d137e6..0f8d05979 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ toolz cftime>=1.5.0 xskillscore>=0.0.20 cf_xarray>=0.6.0 +mypy_extensions From 82a9798d45fde469d9ba30a564851724c858a235 Mon Sep 17 00:00:00 
2001 From: Aaron Spring Date: Mon, 29 Nov 2021 22:45:02 +0100 Subject: [PATCH 15/25] Apply suggestions from code review --- climpred/graphics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/climpred/graphics.py b/climpred/graphics.py index 18237bac3..67f60576a 100644 --- a/climpred/graphics.py +++ b/climpred/graphics.py @@ -7,7 +7,7 @@ from xarray.coding.times import infer_calendar_name from .checks import DimensionError -from .classes import HindcastEnsemble, PerfectModelEnsemble +# from .classes import HindcastEnsemble, PerfectModelEnsemble from .constants import CLIMPRED_DIMS from .metrics import ALL_METRICS, PROBABILISTIC_METRICS from .utils import get_lead_cftime_shift_args, get_metric_class, shift_cftime_index @@ -211,7 +211,7 @@ def _check_only_climpred_dims(pe): def plot_lead_timeseries_hindcast( - he: HindcastEnsemble, + he: "HindcastEnsemble", variable: Optional[str] = None, ax: Optional[plt.Axes] = None, show_members: bool = False, @@ -306,7 +306,7 @@ def plot_lead_timeseries_hindcast( def plot_ensemble_perfect_model( - pm: PerfectModelEnsemble, + pm: "PerfectModelEnsemble", variable: Optional[str] = None, ax: Optional[plt.Axes] = None, show_members: bool = False, From a01eb31e82f40aeaee69dfac84655a84442316aa Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 10:44:56 +0100 Subject: [PATCH 16/25] "plt.Axes" --- climpred/graphics.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/climpred/graphics.py b/climpred/graphics.py index 67f60576a..fd1587c2f 100644 --- a/climpred/graphics.py +++ b/climpred/graphics.py @@ -7,7 +7,7 @@ from xarray.coding.times import infer_calendar_name from .checks import DimensionError -# from .classes import HindcastEnsemble, PerfectModelEnsemble +from .classes import HindcastEnsemble, PerfectModelEnsemble from .constants import CLIMPRED_DIMS from .metrics import ALL_METRICS, PROBABILISTIC_METRICS from .utils import get_lead_cftime_shift_args, 
get_metric_class, shift_cftime_index @@ -82,7 +82,7 @@ def plot_relative_entropy(rel_ent, rel_ent_threshold=None, **kwargs): def plot_bootstrapped_skill_over_leadyear( bootstrapped: xr.Dataset, - ax: Optional[plt.Axes] = None, + ax: Optional["plt.Axes"] = None, color_initialized: str = "indianred", color_uninitialized: str = "steelblue", color_persistence: str = "gray", @@ -91,7 +91,7 @@ def plot_bootstrapped_skill_over_leadyear( fontsize: Union[int, float] = 8, figsize: Tuple = (10, 4), fmt: str = "--o", -) -> plt.Axes: +) -> "plt.Axes": """ Plot Ensemble Prediction skill as in Li et al. 2016 Fig.3a-c. @@ -99,7 +99,7 @@ def plot_bootstrapped_skill_over_leadyear( bootstrapped (xr.DataArray or xr.Dataset with one variable): from PredictionEnsembleEnsemble.bootstrap() or HindcastEnsemble.bootstrap() - ax (plt.axes): plot on ax. Defaults to None. + ax ("plt.Axes"): plot on ax. Defaults to None. Returns: ax @@ -211,13 +211,13 @@ def _check_only_climpred_dims(pe): def plot_lead_timeseries_hindcast( - he: "HindcastEnsemble", + he: HindcastEnsemble, variable: Optional[str] = None, - ax: Optional[plt.Axes] = None, + ax: Optional["plt.Axes"] = None, show_members: bool = False, cmap: Optional[str] = "viridis", x: str = "time", -) -> plt.Axes: +) -> "plt.Axes": """Plot datasets from HindcastEnsemble. Args: @@ -306,13 +306,13 @@ def plot_lead_timeseries_hindcast( def plot_ensemble_perfect_model( - pm: "PerfectModelEnsemble", + pm: PerfectModelEnsemble, variable: Optional[str] = None, - ax: Optional[plt.Axes] = None, + ax: Optional["plt.Axes"] = None, show_members: bool = False, cmap: Optional[str] = "tab10", x: str = "time", -) -> plt.Axes: +) -> "plt.Axes": """Plot datasets from PerfectModelEnsemble. 
Args: From 049e179661a6e9992bc6aa45d489b443b9577145 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 10:51:46 +0100 Subject: [PATCH 17/25] import graphics later in plot --- climpred/classes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index 904bc4f63..a00548c28 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -1,6 +1,6 @@ import warnings from copy import deepcopy -from typing import ( # TYPE_CHECKING,; Collection,; DefaultDict,; MutableMapping, +from typing import ( Any, Callable, Collection, @@ -62,7 +62,6 @@ XCLIM_BIAS_CORRECTION_METHODS, ) from .exceptions import DimensionError, VariableError -from .graphics import plot_ensemble_perfect_model, plot_lead_timeseries_hindcast from .logging import log_compute_hindcast_header from .metrics import Metric from .options import OPTIONS @@ -357,11 +356,11 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: def plot( self, variable: Optional[str] = None, - ax: Optional[plt.Axes] = None, + ax: Optional["plt.Axes"] = None, show_members: bool = False, cmap: Optional[str] = None, x: str = "time", - ) -> plt.Axes: + ) -> "plt.Axes": """Plot datasets from PredictionEnsemble. 
Args: @@ -388,6 +387,8 @@ def plot( ax: plt.axes """ + from .graphics import plot_ensemble_perfect_model, plot_lead_timeseries_hindcast + if x == "time": x = "valid_time" assert x in ["valid_time", "init"] From 0fad8b1c1e4f6ff3a3449a621878f620f6c85147 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 11:12:10 +0100 Subject: [PATCH 18/25] fix comparisons --- climpred/tests/test_comparisons.py | 48 ++++++++++++++---------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/climpred/tests/test_comparisons.py b/climpred/tests/test_comparisons.py index a29621655..eb001ddef 100644 --- a/climpred/tests/test_comparisons.py +++ b/climpred/tests/test_comparisons.py @@ -142,35 +142,33 @@ def test_all(PM_da_initialized_1d, comparison, metric): assert set(forecast.dims) - set(["member"]) == set(obs.dims) -def my_m2me_comparison(ds, metric=None): - """Identical to m2e but median.""" - reference_list = [] - forecast_list = [] - supervector_dim = "member" - for m in ds.member.values: - forecast = ds.drop_sel(member=m).median("member") - reference = ds.sel(member=m).squeeze() - forecast_list.append(forecast) - reference_list.append(reference) - reference = xr.concat(reference_list, supervector_dim) - forecast = xr.concat(forecast_list, supervector_dim) - forecast[supervector_dim] = np.arange(forecast[supervector_dim].size) - reference[supervector_dim] = np.arange(reference[supervector_dim].size) - return forecast, reference - - -my_m2me_comparison = Comparison( - name="m2me", - function=my_m2me_comparison, - probabilistic=False, - hindcast=False, -) - - @pytest.mark.parametrize("metric", ("rmse", "pearson_r")) def test_new_comparison_passed_to_compute( PM_da_initialized_1d, PM_da_control_1d, metric ): + def my_m2me_comparison(ds, metric=None): + """Identical to m2e but median.""" + reference_list = [] + forecast_list = [] + supervector_dim = "member" + for m in ds.member.values: + forecast = ds.drop_sel(member=m).median("member") + reference = 
ds.sel(member=m).squeeze() + forecast_list.append(forecast) + reference_list.append(reference) + reference = xr.concat(reference_list, supervector_dim) + forecast = xr.concat(forecast_list, supervector_dim) + forecast[supervector_dim] = np.arange(forecast[supervector_dim].size) + reference[supervector_dim] = np.arange(reference[supervector_dim].size) + return forecast, reference + + my_m2me_comparison = Comparison( + name="m2me", + function=my_m2me_comparison, + probabilistic=False, + hindcast=False, + ) + actual = compute_perfect_model( PM_da_initialized_1d, PM_da_control_1d, From 07e9867c88a7e4ad81ec22ef6710fe90d32c01c4 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 11:15:18 +0100 Subject: [PATCH 19/25] plot Any --- climpred/classes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index a00548c28..cb169806c 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -356,11 +356,11 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: def plot( self, variable: Optional[str] = None, - ax: Optional["plt.Axes"] = None, + ax: Optional[Any] = None, # actually plt.Axes but plt not in requirements show_members: bool = False, cmap: Optional[str] = None, x: str = "time", - ) -> "plt.Axes": + ) -> Any: """Plot datasets from PredictionEnsemble. 
Args: From b8246cc5dd5daff83e3791236bd0187ff401f45b Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 11:18:43 +0100 Subject: [PATCH 20/25] plot Any --- climpred/classes.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index cb169806c..c0d2556aa 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -351,8 +351,6 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: return False return id - import matplotlib.pyplot as plt - def plot( self, variable: Optional[str] = None, From f55a8b564313604194fe06c543d7370a64beda01 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 11:23:26 +0100 Subject: [PATCH 21/25] rm useless in smooth --- climpred/classes.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index c0d2556aa..139a53358 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -731,11 +731,6 @@ def smooth( # could be more robust in how it finds these two spatial dimensions regardless # of name. Optional work in progress comment. elif isinstance(smooth_kws, dict): - non_time_dims = [ - dim for dim in smooth_kws.keys() if dim not in ["time", "lead"] - ] - if len(non_time_dims) > 0: - non_time_dims = non_time_dims[0] # goddard when time_dim and lon/lat given if ("lon" in smooth_kws or "lat" in smooth_kws) and ( "lead" in smooth_kws or "time" in smooth_kws From 79d9b12ebf6406bc0d7ece74c3561109eeef39ac Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 12:23:29 +0100 Subject: [PATCH 22/25] mypy smooth --- climpred/classes.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index 139a53358..0c315887d 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -208,7 +208,7 @@ def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]): # run. 
self._datasets = {"initialized": xobj, "uninitialized": {}} self.kind = "prediction" - self._temporally_smoothed = None + self._temporally_smoothed: Optional[Dict[str, int]] = None self._is_annual_lead = None self._warn_if_chunked_along_init_member_time() @@ -656,9 +656,9 @@ def get_uninitialized(self) -> xr.Dataset: def smooth( self, - smooth_kws=None, - how="mean", - **xesmf_kwargs, + smooth_kws: Optional[Union[str, Dict[str, int]]] = None, + how: str = "mean", + **xesmf_kwargs: Any, ): """Smooth all entries of PredictionEnsemble in the same manner to be able to still calculate prediction skill afterwards. @@ -711,6 +711,8 @@ def smooth( """ if not smooth_kws: return self + tsmooth_kws: Optional[Union[str, Dict[str, int]]] = None + d_lon_lat_kws: Optional[Union[str, Dict[str, int]]] = None # get proper smoothing function based on smooth args if isinstance(smooth_kws, str): if "goddard" in smooth_kws: @@ -893,7 +895,7 @@ def _apply_climpred_function( func: Callable[..., Any], input_dict: Dict[str, Any], **kwargs: Any, - ): + ) -> Union["PerfectModelEnsemble", xr.Dataset]: """Helper function to loop through observations and apply an arbitrary climpred function. @@ -1450,7 +1452,7 @@ def __init__(self, xobj: Union[xr.DataArray, xr.Dataset]) -> None: def _apply_climpred_function( self, func: Callable[..., Any], init: bool, **kwargs: Any - ) -> "HindcastEnsemble": + ) -> Union["HindcastEnsemble", xr.Dataset]: """Helper function to loop through verification data and apply an arbitrary climpred function. 
From 93e88bd496df3083259ead8ea4ff6772d9b9b9da Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 12:39:48 +0100 Subject: [PATCH 23/25] get rid of mypy_extensions --- ci/requirements/climpred-dev.yml | 1 - ci/requirements/docs.yml | 1 - ci/requirements/docs_notebooks.yml | 1 - ci/requirements/maximum-tests.yml | 1 - ci/requirements/minimum-tests.yml | 1 - climpred/classes.py | 3 +-- requirements.txt | 1 - 7 files changed, 1 insertion(+), 8 deletions(-) diff --git a/ci/requirements/climpred-dev.yml b/ci/requirements/climpred-dev.yml index 867a4ccec..d9be284f6 100644 --- a/ci/requirements/climpred-dev.yml +++ b/ci/requirements/climpred-dev.yml @@ -37,7 +37,6 @@ dependencies: - pytest - pytest-cov - pytest-sugar - - mypy_extensions # Performance - bottleneck - numba diff --git a/ci/requirements/docs.yml b/ci/requirements/docs.yml index 493830743..a67e573ab 100644 --- a/ci/requirements/docs.yml +++ b/ci/requirements/docs.yml @@ -7,7 +7,6 @@ dependencies: - importlib_metadata - matplotlib-base - nbsphinx - - mypy_extensions - nc-time-axis - netcdf4 - sphinx diff --git a/ci/requirements/docs_notebooks.yml b/ci/requirements/docs_notebooks.yml index 471957643..3c4331dd4 100644 --- a/ci/requirements/docs_notebooks.yml +++ b/ci/requirements/docs_notebooks.yml @@ -18,7 +18,6 @@ dependencies: - sphinxcontrib-napoleon - sphinx_rtd_theme - sphinx-copybutton - - mypy_extensions - toolz - xarray>=0.16.1 - esmtools>=1.1.3 diff --git a/ci/requirements/maximum-tests.yml b/ci/requirements/maximum-tests.yml index c98e0bdd5..609f17beb 100644 --- a/ci/requirements/maximum-tests.yml +++ b/ci/requirements/maximum-tests.yml @@ -17,7 +17,6 @@ dependencies: - pytest - pytest-cov - pytest-xdist - - mypy_extensions - scipy - xarray>=0.16.1 - xesmf diff --git a/ci/requirements/minimum-tests.yml b/ci/requirements/minimum-tests.yml index 994cc6c57..bf479d0ea 100644 --- a/ci/requirements/minimum-tests.yml +++ b/ci/requirements/minimum-tests.yml @@ -16,7 +16,6 @@ dependencies: - scipy - 
xarray>=0.16.1 - xskillscore>=0.0.18 - - mypy_extensions - pip: - pytest-lazy-fixture - -e ../.. diff --git a/climpred/classes.py b/climpred/classes.py index 0c315887d..32880cdc3 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -22,7 +22,6 @@ import xarray as xr from dask import is_dask_collection from IPython.display import display_html -from mypy_extensions import KwArg, VarArg from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataset import DataVariables from xarray.core.formatting_html import dataset_repr @@ -526,7 +525,7 @@ def sel_vars(ds, varlist): return self._apply_func(sel_vars, varlist) - def __getattr__(self, name: str) -> Callable[[VarArg(Any), KwArg(Any)], Any]: + def __getattr__(self, name): """Allows for xarray methods to be applied to our prediction objects. Args: diff --git a/requirements.txt b/requirements.txt index 0f8d05979..757d137e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,3 @@ toolz cftime>=1.5.0 xskillscore>=0.0.20 cf_xarray>=0.6.0 -mypy_extensions From ce3d836bfe4205665d204243b423d9957bad02b4 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 12:54:08 +0100 Subject: [PATCH 24/25] mypy plot --- climpred/classes.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/climpred/classes.py b/climpred/classes.py index 32880cdc3..411131137 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -91,6 +91,15 @@ groupbyType = Optional[Union[str, xr.DataArray]] metric_kwargsType = Optional[Any] +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import matplotlib.pyplot as plt + + optionalaxisType = Optional[plt.Axes] +else: + optionalaxisType = Optional[Any] + def _display_metadata(self) -> str: """ @@ -353,11 +362,11 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: def plot( self, variable: Optional[str] = None, - ax: Optional[Any] = None, # actually plt.Axes but plt not in requirements + ax: optionalaxisType = None, # 
actually plt.Axes but plt not in requirements show_members: bool = False, cmap: Optional[str] = None, x: str = "time", - ) -> Any: + ) -> "plt.Axes": """Plot datasets from PredictionEnsemble. Args: From 1139fa909688431028f3b460894a263ac066f2b7 Mon Sep 17 00:00:00 2001 From: AS Date: Tue, 30 Nov 2021 13:17:59 +0100 Subject: [PATCH 25/25] finalize --- CHANGELOG.rst | 4 +--- climpred/classes.py | 20 +++++++++----------- climpred/metrics.py | 4 +--- 3 files changed, 11 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index e37bc121d..151efa69e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -45,12 +45,10 @@ New Features :py:meth:`~climpred.classes.PerfectModelEnsemble.bootstrap` to group skill by initializations seasonality. (:issue:`635`, :pr:`690`) `Aaron Spring`_. -Documentation -------------- - Internals/Minor Fixes --------------------- - Reduce dependencies (:pr:`686`) `Aaron Spring`_. +- Add `typing `_ (:issue:`685`, :pr:`692`) `Aaron Spring`_. climpred v2.1.6 (2021-08-31) diff --git a/climpred/classes.py b/climpred/classes.py index 411131137..85e4e4914 100644 --- a/climpred/classes.py +++ b/climpred/classes.py @@ -3,16 +3,12 @@ from typing import ( Any, Callable, - Collection, Dict, Hashable, - Iterable, Iterator, List, Mapping, Optional, - Sequence, - Set, Tuple, Union, ) @@ -362,7 +358,7 @@ def identical(self, other: Union["PredictionEnsemble", Any]) -> bool: def plot( self, variable: Optional[str] = None, - ax: optionalaxisType = None, # actually plt.Axes but plt not in requirements + ax: optionalaxisType = None, show_members: bool = False, cmap: Optional[str] = None, x: str = "time", @@ -534,7 +530,9 @@ def sel_vars(ds, varlist): return self._apply_func(sel_vars, varlist) - def __getattr__(self, name): + def __getattr__( + self, name: str + ) -> Callable: # -> Callable[[VarArg(Any), KwArg(Any)], Any] """Allows for xarray methods to be applied to our prediction objects. 
Args: @@ -1311,7 +1309,11 @@ def bootstrap( If None or empty, returns no p value. iterations (int): Number of resampling iterations for bootstrapping with replacement. Recommended >= 500. - resample_dim (str): dimension for resampling + resample_dim (str or list): dimension to resample from. default: 'member'. + + - 'member': select a different set of members from hind + - 'init': select a different set of initializations from hind + sig (int, default 95): Significance level in percent for deciding whether uninitialized and persistence beat initialized skill. pers_sig (int): If not ``None``, the separate significance level for @@ -1641,10 +1643,6 @@ def verify( Data variables: SST (skill, lead) float64 0.9023 0.8807 0.8955 ... 0.9078 0.9128 0.9159 """ - if isinstance(reference, str): - reference = [reference] - else: - pass # reference = list(reference) if groupby is not None: skill_group = [] group_label = [] diff --git a/climpred/metrics.py b/climpred/metrics.py index a0f0cee10..ec9c0ec42 100644 --- a/climpred/metrics.py +++ b/climpred/metrics.py @@ -1,4 +1,5 @@ import warnings +from typing import Callable, List, Optional import numpy as np import pandas as pd @@ -189,9 +190,6 @@ def _display_metric_metadata(self) -> str: return summary -from typing import Callable, List, Optional - - class Metric: """Master class for all metrics."""