diff --git a/clouddrift/adapters/gdp/__init__.py b/clouddrift/adapters/gdp/__init__.py
index e692aea3..ac01a6fc 100644
--- a/clouddrift/adapters/gdp/__init__.py
+++ b/clouddrift/adapters/gdp/__init__.py
@@ -7,10 +7,12 @@
 import os
 import tempfile
+import typing
 
 import numpy as np
 import pandas as pd
 import xarray as xr
+from numpy.typing import NDArray
 
 from clouddrift.adapters.utils import download_with_progress
 from clouddrift.raggedarray import DimNames
@@ -269,7 +271,9 @@ def str_to_float(value: str, default: float = np.nan) -> float:
         return default
 
 
-def cut_str(value: str, max_length: int) -> np.chararray:
+def cut_str(
+    value: str, max_length: int
+) -> np.chararray[typing.Any, np.dtype[np.bytes_]]:
     """Cut a string to a specific length and return it as a numpy chararray.
 
     Parameters
@@ -289,7 +293,7 @@ def cut_str(value: str, max_length: int) -> np.chararray:
     return charar
 
 
-def drogue_presence(lost_time, time) -> np.ndarray:
+def drogue_presence(lost_time, time) -> NDArray[typing.Any]:
     """Create drogue status from the drogue lost time and the trajectory time.
 
     Parameters
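The switch from bare `np.ndarray` to `NDArray[...]` throughout this diff is what lets the annotations survive `disallow_any_generics` (enabled in `pyproject.toml` further down). A minimal sketch of the distinction, assuming numpy's `numpy.typing` module (numpy >= 1.21):

```python
import numpy as np
from numpy.typing import NDArray

# Bare `np.ndarray` is an implicitly parameterized generic and is reported
# as missing type parameters under disallow_any_generics:
# def scale(x: np.ndarray) -> np.ndarray: ...

# NDArray spells the dtype parameter out explicitly:
def scale(x: NDArray[np.float64], factor: float) -> NDArray[np.float64]:
    return x * factor  # elementwise; dtype is preserved
```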
diff --git a/clouddrift/adapters/gdp/gdpsource.py b/clouddrift/adapters/gdp/gdpsource.py
index 595f533f..793172a0 100644
--- a/clouddrift/adapters/gdp/gdpsource.py
+++ b/clouddrift/adapters/gdp/gdpsource.py
@@ -5,14 +5,15 @@
 import logging
 import os
 import tempfile
+import typing
 import warnings
 from collections import defaultdict
 from concurrent.futures import Future, ProcessPoolExecutor, as_completed
-from typing import Callable
 
 import numpy as np
 import pandas as pd
 import xarray as xr
+from numpy.typing import NDArray
 from tqdm.asyncio import tqdm
 
 from clouddrift.adapters.gdp import get_gdp_metadata
@@ -56,7 +57,7 @@
     "death_code",
 ]
 
-_VARS_FILL_MAP: dict = {
+_VARS_FILL_MAP: dict[str, int | str | np.datetime64] = {
     "wmo_number": -999,
     "program_number": -999,
     "buoys_type": "N/A",
@@ -70,7 +71,7 @@
     "death_code": -999,
 }
 
-_VAR_DTYPES: dict = {
+_VAR_DTYPES: dict[str, type | np.dtype[np.datetime64]] = {
     "rowsize": np.int64,
     "wmo_number": np.int64,
     "program_number": np.int64,
@@ -126,7 +127,7 @@
 }
 
 
-VARS_ATTRS: dict = {
+VARS_ATTRS: dict[str, dict[str, str]] = {
     "id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"},
     "rowsize": {
         "long_name": "Number of observations per trajectory",
@@ -334,7 +335,9 @@ def _preprocess(id_, **kwargs) -> xr.Dataset:
     return dataset
 
 
-def _apply_remove(df: pd.DataFrame, filters: list[Callable]) -> pd.DataFrame:
+def _apply_remove(
+    df: pd.DataFrame, filters: list[typing.Callable[..., typing.Any]]
+) -> pd.DataFrame:
     temp_df = df
     for filter_ in filters:
         mask = filter_(temp_df)
@@ -344,7 +347,7 @@ def _apply_remove(df: pd.DataFrame, filters: list[Callable]) -> pd.DataFrame:
 
 def _apply_transform(
     df: pd.DataFrame,
-    transforms: dict[str, tuple[list[str], Callable]],
+    transforms: dict[str, tuple[list[str], typing.Callable[..., typing.Any]]],
 ) -> pd.DataFrame:
     tmp_df = df
     for output_col in transforms.keys():
@@ -359,8 +362,10 @@ def _apply_transform(
 
 
 def _parse_datetime_with_day_ratio(
-    month_series: np.ndarray, day_series: np.ndarray, year_series: np.ndarray
-) -> np.ndarray:
+    month_series: NDArray[np.float32],
+    day_series: NDArray[np.float32],
+    year_series: NDArray[np.float32],
+) -> NDArray[np.datetime64]:
     values = list()
     for month, day_with_ratio, year in zip(month_series, day_series, year_series):
         day = day_with_ratio // 1
@@ -479,7 +484,7 @@ def _combine_chunked_drifter_datasets(datasets: list[xr.Dataset]) -> xr.Dataset:
     )
     sort_coord = traj_dataset.coords["obs_index"]
-    vals: np.ndarray = sort_coord.data
+    vals: NDArray[np.int64] = sort_coord.data
     sort_coord_dim = sort_coord.dims[-1]
     sort_key = vals.argsort()
@@ -531,8 +536,8 @@ async def _parallel_get(
         chunksize=chunk_size,
     )
 
-    joblist = list[Future]()
-    jobmap = dict[Future, pd.DataFrame]()
+    joblist = list[Future[dict[int, xr.Dataset]]]()
+    jobmap = dict[Future[dict[int, xr.Dataset]], pd.DataFrame]()
     for idx, chunk in enumerate(file_chunks):
         if max_chunks is not None and idx >= max_chunks:
             break
@@ -568,7 +573,7 @@ async def _parallel_get(
             drifter_chunked_datasets[id_].append(drifter_ds)
             bar.update()
 
-    combine_jobmap = dict[Future, int]()
+    combine_jobmap = dict[Future[xr.Dataset], int]()
     for id_ in drifter_chunked_datasets.keys():
         datasets = drifter_chunked_datasets[id_]
diff --git a/clouddrift/adapters/hurdat2.py b/clouddrift/adapters/hurdat2.py
index da1166d3..a9f3cab3 100644
--- a/clouddrift/adapters/hurdat2.py
+++ b/clouddrift/adapters/hurdat2.py
@@ -390,7 +390,9 @@ def to_raggedarray(
         preprocess_func=lambda idx: track_data[idx].to_xarray_dataset(),
         attrs_global=TrackData.global_attrs,
         attrs_variables={
-            field.name: field.metadata
+            field.name: dict(
+                field.metadata
+            )  # type cast needed for static type analysis
             for field in fields(HeaderLine) + fields(DataLine)
         },
     )
diff --git a/clouddrift/adapters/utils.py b/clouddrift/adapters/utils.py
index 26a61159..9d17ceda 100644
--- a/clouddrift/adapters/utils.py
+++ b/clouddrift/adapters/utils.py
@@ -1,15 +1,14 @@
 import concurrent.futures
 import logging
 import os
+import typing
 import urllib
 from datetime import datetime
 from io import BufferedIOBase, BufferedWriter
-from typing import Callable, Sequence
 
 import requests
 from tenacity import (
     RetryCallState,
-    WrappedFn,
     retry,
     retry_if_exception,
     stop_after_attempt,
@@ -30,7 +29,10 @@ def _before_call(rcs: RetryCallState):
 _CHUNK_SIZE = 1_048_576  # 1MiB
 _logger = logging.getLogger(__name__)
 
-standard_retry_protocol: Callable[[WrappedFn], WrappedFn] = retry(
+_Func = typing.Callable[..., typing.Any]
+_Wrapper = typing.Callable[[_Func], _Func]
+
+standard_retry_protocol: _Wrapper = retry(
     retry=retry_if_exception(
         lambda ex: isinstance(
             ex,
@@ -51,19 +53,19 @@ def _before_call(rcs: RetryCallState):
 
 
 def download_with_progress(
-    download_map: Sequence[
+    download_map: typing.Sequence[
         tuple[str, BufferedIOBase | str] | tuple[str, BufferedIOBase | str, float]
     ],
     show_list_progress: bool | None = None,
     desc: str = "Downloading files",
-    custom_retry_protocol: Callable[[WrappedFn], WrappedFn] | None = None,
+    custom_retry_protocol: _Wrapper | None = None,
 ):
     if show_list_progress is None:
         show_list_progress = len(download_map) > 20
     if custom_retry_protocol is None:
         retry_protocol = standard_retry_protocol
     else:
-        retry_protocol = custom_retry_protocol  # type: ignore
+        retry_protocol = custom_retry_protocol
 
     executor = concurrent.futures.ThreadPoolExecutor()
     futures: dict[
@@ -156,10 +158,10 @@ def _download_with_progress(
         )
 
     _logger.debug(f"Downloading from {url} to {output}...")
-    bar = None
-
     with requests.get(url, timeout=5, stream=True) as response:
         buffer: BufferedWriter | BufferedIOBase | None = None
+        bar = None
+
         try:
             if isinstance(output, (str,)):
                 buffer = open(output, "wb")
@@ -179,6 +181,7 @@ def _download_with_progress(
                     nrows=2,
                     disable=_DISABLE_SHOW_PROGRESS,
                 )
+
             for chunk in response.iter_content(_CHUNK_SIZE):
                 if not chunk:
                     break
@@ -186,8 +189,6 @@ def _download_with_progress(
                 if bar is not None:
                     bar.update(len(chunk))
         finally:
-            if response is not None:
-                response.close()
             if bar is not None:
                 bar.close()
             if buffer is not None and isinstance(output, (str,)):
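Replacing tenacity's `WrappedFn` with the local `_Func`/`_Wrapper` aliases removes the import but also erases the decorated function's precise signature; a hedged sketch of the trade-off (the `preserve_sig` variant is illustrative only, not part of this diff):

```python
import typing

_Func = typing.Callable[..., typing.Any]
_Wrapper = typing.Callable[[_Func], _Func]

F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any])

def erase_sig(func: _Func) -> _Func:
    # What _Wrapper expresses: any callable in, any callable out.
    return func

def preserve_sig(func: F) -> F:
    # A TypeVar keeps the exact signature, as tenacity's WrappedFn did.
    return func
```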
diff --git a/clouddrift/kinematics.py b/clouddrift/kinematics.py
index 222db123..43786543 100644
--- a/clouddrift/kinematics.py
+++ b/clouddrift/kinematics.py
@@ -3,8 +3,8 @@
 """
 
 import numpy as np
-import pandas as pd
 import xarray as xr
+from numpy.typing import NDArray
 
 from clouddrift.sphere import (
     EARTH_RADIUS_METERS,
@@ -17,13 +17,14 @@
     recast_lon360,
     spherical_to_cartesian,
 )
+from clouddrift.typing import ArrayTypes
 from clouddrift.wavelet import morse_logspace_freq, morse_wavelet, wavelet_transform
 
 
 def kinetic_energy(
-    u: float | list | np.ndarray | xr.DataArray | pd.Series,
-    v: float | list | np.ndarray | xr.DataArray | pd.Series | None = None,
-) -> float | np.ndarray | xr.DataArray:
+    u: float | ArrayTypes,
+    v: float | ArrayTypes | None = None,
+) -> float | NDArray[np.float64] | xr.DataArray:
     """Compute kinetic energy from zonal and meridional velocities.
 
     Parameters
@@ -531,7 +532,7 @@ def velocity_from_position(
     coord_system: str = "spherical",
     difference_scheme: str = "forward",
     time_axis: int = -1,
-) -> tuple[xr.DataArray, xr.DataArray]:
+) -> tuple[np.ndarray, np.ndarray]:
     """Compute velocity from arrays of positions and time.
 
     x and y can be provided as longitude and latitude in degrees if
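With `u` and `v` widened to `float | ArrayTypes`, the same call site now type-checks for scalars, lists, arrays, Series, and DataArrays; a small usage sketch (kinetic energy is `(u**2 + v**2) / 2`):

```python
import numpy as np
from clouddrift.kinematics import kinetic_energy

ke_scalar = kinetic_energy(1.0, 2.0)  # float in, scalar out
ke_array = kinetic_energy(np.array([1.0, 2.0]), np.array([0.5, 0.5]))
```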
diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py
index f08a3c9e..2c915bd4 100644
--- a/clouddrift/ragged.py
+++ b/clouddrift/ragged.py
@@ -2,6 +2,7 @@
 Transformational and inquiry functions for ragged arrays.
 """
 
+import typing
 import warnings
 from collections.abc import Callable, Iterable
 from concurrent import futures
@@ -10,18 +11,25 @@
 import numpy as np
 import pandas as pd
 import xarray as xr
+from numpy.typing import NDArray
+
+from clouddrift.typing import ArrayTypes
+
+_ArrayOutput = typing.TypeVar(
+    "_ArrayOutput", bound=tuple[np.ndarray, np.ndarray] | np.ndarray
+)
 
 
 def apply_ragged(
-    func: callable,
-    arrays: list[np.ndarray | xr.DataArray] | np.ndarray | xr.DataArray,
-    rowsize: list[int] | np.ndarray[int] | xr.DataArray,
-    *args: tuple,
+    func: Callable[..., _ArrayOutput],
+    arrays: ArrayTypes,
+    rowsize: ArrayTypes,
+    *args: typing.Any,
     rows: int | Iterable[int] = None,
     axis: int = 0,
     executor: futures.Executor = futures.ThreadPoolExecutor(max_workers=None),
-    **kwargs: dict,
-) -> tuple[np.ndarray] | np.ndarray:
+    **kwargs: typing.Any,
+) -> _ArrayOutput:
     """Apply a function to a ragged array.
 
     The function ``func`` will be applied to each contiguous row of ``arrays`` as
@@ -423,7 +431,7 @@ def regular_to_ragged(
     return array[valid], np.sum(valid, axis=1)
 
 
-def rowsize_to_index(rowsize: list | np.ndarray | xr.DataArray) -> np.ndarray:
+def rowsize_to_index(rowsize: ArrayTypes) -> np.ndarray:
     """Convert a list of row sizes to a list of indices.
 
     This function is typically used to obtain the indices of data rows organized
@@ -450,10 +458,10 @@ def rowsize_to_index(rowsize: list | np.ndarray | xr.DataArray) -> np.ndarray:
 
 
 def segment(
-    x: np.ndarray,
+    x: ArrayTypes,
     tolerance: float | np.timedelta64 | timedelta | pd.Timedelta,
-    rowsize: np.ndarray[int] = None,
-) -> np.ndarray[int]:
+    rowsize: NDArray[np.int64] | None = None,
+) -> NDArray[np.int64]:
     """Divide an array into segments based on a tolerance value.
 
     Parameters
@@ -787,11 +795,11 @@ def subset(
 
 
 def unpack(
-    ragged_array: np.ndarray,
-    rowsize: np.ndarray[int],
-    rows: int | Iterable[int] = None,
+    ragged_array: ArrayTypes,
+    rowsize: ArrayTypes,
+    rows: int | np.int64 | Iterable[int] | None = None,
     axis: int = 0,
-) -> list[np.ndarray]:
+) -> list[NDArray[typing.Any]]:
     """Unpack a ragged array into a list of regular arrays.
 
     Unpacking a ``np.ndarray`` ragged array is about 2 orders of magnitude
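`_ArrayOutput` ties `apply_ragged`'s return type to whatever `func` returns — one array, or a pair of arrays — replacing the old blanket `tuple[np.ndarray] | np.ndarray`. A sketch of the inference, assuming `func` is itself annotated:

```python
import numpy as np
from numpy.typing import NDArray
from clouddrift.ragged import apply_ragged

def row_mean(x: NDArray[np.float64]) -> NDArray[np.float64]:
    return np.atleast_1d(np.mean(x))

# mypy solves _ArrayOutput as NDArray[np.float64] at this call site.
means = apply_ragged(row_mean, np.array([1.0, 2.0, 3.0, 4.0]), [2, 2])
```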
diff --git a/clouddrift/raggedarray.py b/clouddrift/raggedarray.py
index a7aadb06..529d3334 100644
--- a/clouddrift/raggedarray.py
+++ b/clouddrift/raggedarray.py
@@ -13,9 +13,11 @@
 import awkward as ak  # type: ignore
 import numpy as np
 import xarray as xr
+from numpy.typing import NDArray
 from tqdm import tqdm
 
 from clouddrift.ragged import rowsize_to_index
+from clouddrift.typing import ArrayTypes
 
 DimNames = Literal["rows", "obs"]
 _DISABLE_SHOW_PROGRESS = False  # purely to de-noise our test suite output, should never be used/configured outside of that.
@@ -24,11 +26,11 @@
 class RaggedArray:
     def __init__(
         self,
-        coords: dict,
-        metadata: dict,
-        data: dict,
-        attrs_global: dict = {},
-        attrs_variables: dict = {},
+        coords: dict[str, NDArray[Any]],
+        metadata: dict[str, NDArray[Any]],
+        data: dict[str, NDArray[Any]],
+        attrs_global: dict[str, str] = {},
+        attrs_variables: dict[str, dict[str, str]] = {},
         name_dims: dict[str, DimNames] = {},
         coord_dims: dict[str, str] = {},
     ):
@@ -46,7 +48,7 @@ def __init__(
     def from_awkward(
         cls,
         array: ak.Array,
-        name_coords: list,
+        name_coords: list[str],
         name_dims: dict[str, DimNames],
         coord_dims: dict[str, str],
     ):
@@ -99,15 +101,15 @@ def from_awkward(
     @classmethod
     def from_files(
         cls,
-        indices: list[int],
-        preprocess_func: Callable[[int], xr.Dataset],
-        name_coords: list,
-        name_meta: list = list(),
-        name_data: list = list(),
+        indices: ArrayTypes,
+        preprocess_func: Callable[[Any], xr.Dataset],
+        name_coords: list[str],
+        name_meta: list[str] = list(),
+        name_data: list[str] = list(),
         name_dims: dict[str, DimNames] = {},
         rowsize_func: Callable[[int], int] | None = None,
-        attrs_global: dict | None = None,
-        attrs_variables: dict | None = None,
+        attrs_global: dict[str, str] | None = None,
+        attrs_variables: dict[str, dict[str, str]] | None = None,
         **kwargs,
     ):
         """Generate a ragged array archive from a list of files
@@ -189,7 +191,7 @@ def from_netcdf(cls, filename: str, rows_dim_name="rows", obs_dim_name="obs"):
     def from_parquet(
         cls,
         filename: str,
-        name_coords: list,
+        name_coords: list[str],
         name_dims: dict[str, DimNames],
         coord_dims: dict[str, str],
     ):
@@ -235,9 +237,9 @@ def from_xarray(
         RaggedArray
             A RaggedArray instance
         """
-        coords = {}
-        metadata = {}
-        data = {}
+        coords: dict[str, NDArray[Any]] = {}
+        metadata: dict[str, NDArray[Any]] = {}
+        data: dict[str, NDArray[Any]] = {}
         coord_dims = {}
         name_dims: dict[str, DimNames] = {rows_dim_name: "rows", obs_dim_name: "obs"}
         attrs_global = {}
@@ -253,6 +255,7 @@ def from_xarray(
                 attrs_variables[var] = ds[var].attrs
 
         for var in ds.data_vars.keys():
+            var = str(var)
             if len(ds[var]) == ds.sizes.get(rows_dim_name):
                 metadata[var] = ds[var].data
             elif len(ds[var]) == ds.sizes.get(obs_dim_name):
@@ -273,8 +276,8 @@ def from_xarray(
 
     @staticmethod
     def number_of_observations(
-        rowsize_func: Callable[[int], int], indices: list, **kwargs
-    ) -> np.ndarray:
+        rowsize_func: Callable[[int], int], indices: ArrayTypes, **kwargs
+    ) -> NDArray[np.int64]:
         """Iterate through the files and evaluate the number of observations.
 
         Parameters
@@ -305,10 +308,10 @@ def number_of_observations(
     @staticmethod
    def attributes(
         ds: xr.Dataset,
-        name_coords: list,
-        name_meta: list,
-        name_data: list,
-    ) -> tuple[dict, dict]:
+        name_coords: list[str],
+        name_meta: list[str],
+        name_data: list[str],
+    ) -> tuple[dict[str, str], dict[str, dict[str, str]]]:
         """Return global attributes and the attributes of all variables
         (name_coords, name_meta, and name_data) from an Xarray Dataset.
 
@@ -342,15 +345,20 @@ def attributes(
     @staticmethod
     def allocate(
-        preprocess_func: Callable[[int], xr.Dataset],
-        indices: list,
-        rowsize: list | np.ndarray | xr.DataArray,
-        name_coords: list,
-        name_meta: list,
-        name_data: list,
+        preprocess_func: Callable[[Any], xr.Dataset],
+        indices: ArrayTypes,
+        rowsize: ArrayTypes,
+        name_coords: list[str],
+        name_meta: list[str],
+        name_data: list[str],
         name_dims: dict[str, DimNames],
         **kwargs,
-    ) -> tuple[dict, dict, dict, dict]:
+    ) -> tuple[
+        dict[str, NDArray[Any]],
+        dict[str, NDArray[Any]],
+        dict[str, NDArray[Any]],
+        dict[str, str],
+    ]:
         """
         Iterate through the files and fill for the ragged array associated
         with coordinates, and selected metadata and data variables.
@@ -394,7 +402,7 @@ def allocate(
         coords = {}
         coord_dims: dict[str, str] = {}
         for var in name_coords:
-            dim = ds[var].dims[-1]
+            dim = str(ds[var].dims[-1])
             dim_size = dim_sizes[dim]
             coords[var] = np.zeros(dim_size, dtype=ds[var].dtype)
             coord_dims[var] = dim
@@ -427,7 +435,7 @@ def allocate(
             oid = index_traj[i]
 
             for var in name_coords:
-                dim = ds[var].dims[-1]
+                dim = str(ds[var].dims[-1])
                 if name_dims[dim] == "obs":
                     coords[var][oid : oid + size] = ds[var].data
                 else:
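`DimNames` is the two-value `Literal` declared near the top of this module, so the new `dict[str, DimNames]` annotations let mypy reject any dimension role other than `"rows"` or `"obs"`:

```python
from clouddrift.raggedarray import DimNames

name_dims: dict[str, DimNames] = {"traj": "rows", "obs": "obs"}
# name_dims["depth"] = "levels"  # rejected: not a DimNames literal
```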
diff --git a/clouddrift/sphere.py b/clouddrift/sphere.py
index 510efaaa..9be6c777 100644
--- a/clouddrift/sphere.py
+++ b/clouddrift/sphere.py
@@ -3,6 +3,7 @@
 """
 
 import warnings
+from typing import TypeVar
 
 import numpy as np
 import xarray as xr
@@ -207,7 +208,10 @@ def bearing(
 
 
 def position_from_distance_and_bearing(
-    lon: float, lat: float, distance: float, bearing: float
+    lon: float | np.ndarray,
+    lat: float | np.ndarray,
+    distance: float | np.ndarray,
+    bearing: float | np.ndarray,
 ) -> tuple[float, float]:
     """Return elementwise new position in degrees from arrays of latitude and
     longitude in degrees, distance in meters, and bearing in radians, based on
@@ -660,13 +664,13 @@ def cartesian_to_spherical(
     return lon, lat
 
 
+T = TypeVar("T", bound=float | np.ndarray)
+V = TypeVar("V", bound=float | np.ndarray)
+
+
 def cartesian_to_tangentplane(
-    u: float | np.ndarray,
-    v: float | np.ndarray,
-    w: float | np.ndarray,
-    longitude: float | np.ndarray,
-    latitude: float | np.ndarray,
-) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]:
+    u: T, v: T, w: T, longitude: V, latitude: V
+) -> tuple[T, T]:
     """
     Project a three-dimensional Cartesian vector on a plane tangent to
     a spherical Earth.
@@ -725,12 +729,9 @@ def cartesian_to_tangentplane(
     return u_projected, v_projected
 
 
 def tangentplane_to_cartesian(
-    up: float | np.ndarray,
-    vp: float | np.ndarray,
-    longitude: float | np.ndarray,
-    latitude: float | np.ndarray,
-) -> tuple[float, float, float] | tuple[np.ndarray, np.ndarray, np.ndarray]:
+    up: T, vp: T, longitude: T, latitude: T
+) -> tuple[T, T, T]:
     """
     Return the three-dimensional Cartesian components of a vector contained
     in a plane tangent to a spherical Earth.
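Binding `T` and `V` once, instead of repeating `float | np.ndarray` unions, also tells mypy that scalar inputs give scalar outputs and array inputs give array outputs; a usage sketch:

```python
import numpy as np
from clouddrift.sphere import tangentplane_to_cartesian

# Scalars in, scalars out: T is solved as float at this call site.
x, y, z = tangentplane_to_cartesian(1.0, 0.0, 0.0, 0.0)

# Arrays in, arrays out: T is solved as np.ndarray.
xs, ys, zs = tangentplane_to_cartesian(
    np.ones(3), np.zeros(3), np.zeros(3), np.zeros(3)
)
```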
diff --git a/clouddrift/typing.py b/clouddrift/typing.py
new file mode 100644
index 00000000..3c367665
--- /dev/null
+++ b/clouddrift/typing.py
@@ -0,0 +1,20 @@
+import datetime
+from typing import TYPE_CHECKING, Any, TypeAlias
+
+import numpy as np
+import pandas as pd
+import xarray as xr
+from numpy.typing import NDArray
+
+# Subscripting the type for pandas series only works at type checking time
+if TYPE_CHECKING:
+    _SupportedArrayTypes = list[Any] | NDArray[Any] | pd.Series[Any] | xr.DataArray
+else:
+    _SupportedArrayTypes = list[Any] | NDArray[Any] | pd.Series | xr.DataArray
+
+ArrayTypes: TypeAlias = _SupportedArrayTypes
+
+_SupportedToleranceTypes = pd.Timedelta | datetime.timedelta | np.timedelta64
+ToleranceTypes: TypeAlias = _SupportedToleranceTypes
+
+__all__ = ["ArrayTypes", "ToleranceTypes"]
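As the comment in the new module notes, `pd.Series[Any]` is only subscriptable for the type checker, hence the `TYPE_CHECKING` split; callers simply import the alias. A sketch of a consumer (`smooth` is a hypothetical name, not part of this diff):

```python
import numpy as np
from clouddrift.typing import ArrayTypes

def smooth(values: ArrayTypes, window: int = 3) -> np.ndarray:
    # One alias covers list, ndarray, pd.Series, and xr.DataArray inputs.
    x = np.asarray(values, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="same")
```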
diff --git a/clouddrift/wavelet.py b/clouddrift/wavelet.py
index 81c7237d..b0ce74cb 100644
--- a/clouddrift/wavelet.py
+++ b/clouddrift/wavelet.py
@@ -15,22 +15,53 @@
 not the MATLAB implementation is licensed under CloudDrift's MIT license.
 """
 
+from typing import Any, Literal, overload
+
 import numpy as np
+from numpy.typing import NDArray
 from scipy.special import gamma as _gamma
 from scipy.special import gammaln as _lgamma
 
 
+@overload
 def morse_wavelet_transform(
-    x: np.ndarray,
+    x: NDArray[Any],
     gamma: float,
     beta: float,
-    radian_frequency: np.ndarray,
+    radian_frequency: NDArray[Any],
+    complex: Literal[True],
+    order: int = 1,
+    normalization: str = "bandpass",
+    boundary: str = "mirror",
+    time_axis: int = -1,
+) -> tuple[NDArray[Any], NDArray[Any]]: ...
+
+
+@overload
+def morse_wavelet_transform(
+    x: NDArray[Any],
+    gamma: float,
+    beta: float,
+    radian_frequency: NDArray[Any],
+    complex: Literal[False],
+    order: int = 1,
+    normalization: str = "bandpass",
+    boundary: str = "mirror",
+    time_axis: int = -1,
+) -> NDArray[Any]: ...
+
+
+def morse_wavelet_transform(
+    x: NDArray[Any],
+    gamma: float,
+    beta: float,
+    radian_frequency: NDArray[Any],
     complex: bool = False,
     order: int = 1,
     normalization: str = "bandpass",
     boundary: str = "mirror",
     time_axis: int = -1,
-) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
+) -> tuple[NDArray[Any], NDArray[Any]] | NDArray[Any]:
     """
     Apply a continuous wavelet transform to an input signal using the
     generalized Morse wavelets of Olhede and Walden (2002). The wavelet transform is normalized differently
@@ -185,18 +216,9 @@ def morse_wavelet_transform(
         wtx_n = wavelet_transform(
             np.conj(x / np.sqrt(2)), wavelet, boundary=boundary, time_axis=time_axis
         )
-        wtx = wtx_p, wtx_n
+        return wtx_p, wtx_n
 
-    elif not complex:
-        # real case
-        wtx = wavelet_transform(x, wavelet, boundary=boundary, time_axis=time_axis)
-
-    else:
-        raise ValueError(
-            "`complex` optional argument must be boolean 'True' or 'False'"
-        )
-
-    return wtx
+    return wavelet_transform(x, wavelet, boundary=boundary, time_axis=time_axis)
 
 
 def wavelet_transform(
@@ -490,7 +512,7 @@ def _morse_wavelet_first_family(
     fact: float,
     gamma: float,
     beta: float,
-    norm_radian_frequency: np.ndarray,
+    norm_radian_frequency: NDArray[Any],
     wavezero: np.ndarray,
     order: int = 1,
     normalization: str = "bandpass",
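The paired `Literal[True]`/`Literal[False]` overloads let the checker narrow the return shape from the `complex` argument — which now has to be passed explicitly for narrowing to kick in (hence the `..., False)` updates in the tests below):

```python
import numpy as np
from clouddrift.wavelet import morse_wavelet_transform

x = np.random.random(1024)
radian_frequency = 2 * np.pi / np.logspace(np.log10(10), np.log10(100), 20)

wtx = morse_wavelet_transform(x, 3, 10, radian_frequency, False)  # one array
wtx_p, wtx_n = morse_wavelet_transform(x, 3, 10, radian_frequency, True)  # pair
```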
xr_coords["time"] = ( ["obs"], - coords["time"][i], + t[i], {"long_name": "variable time", "units": "-"}, ) - xr_data: dict[str, Any] = {} + xr_data: dict[str, typing.Any] = {} for var in metadata.keys(): xr_data[var] = ( ["traj"], @@ -207,11 +211,13 @@ def test_chunk_array_like(self): class prune_tests(unittest.TestCase): def test_prune(self): - x = [1, 2, 3, 1, 2, 1, 2, 3, 4] + x: list[int] = [1, 2, 3, 1, 2, 1, 2, 3, 4] rowsize = [3, 2, 4] minimum = 3 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): x_new, rowsize_new = prune(data, rowsize, minimum) self.assertTrue(isinstance(x_new, np.ndarray)) self.assertTrue(isinstance(rowsize_new, np.ndarray)) @@ -223,7 +229,10 @@ def test_prune_all_longer(self): rowsize = [3, 2, 4] minimum = 1 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + data: ArrayTypes + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): x_new, rowsize_new = prune(data, rowsize, minimum) np.testing.assert_equal(x_new, data) np.testing.assert_equal(rowsize_new, rowsize) @@ -233,7 +242,10 @@ def test_prune_all_smaller(self): rowsize = [3, 2, 4] minimum = 5 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + data: ArrayTypes + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): x_new, rowsize_new = prune(data, rowsize, minimum) np.testing.assert_equal(x_new, np.array([])) np.testing.assert_equal(rowsize_new, np.array([])) @@ -262,21 +274,27 @@ def test_prune_dates(self): np.testing.assert_equal(rowsize_new, [5, 8]) def test_prune_keep_nan(self): - x = [1, 2, np.nan, 1, 2, 1, 2, np.nan, 4] + x: list[int | float | np.float64] = [1, 2, np.nan, 1, 2, 1, 2, np.nan, 4] rowsize = [3, 2, 4] minimum = 3 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + data: ArrayTypes + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): x_new, rowsize_new = prune(data, rowsize, minimum) np.testing.assert_equal(x_new, [1, 2, np.nan, 1, 2, np.nan, 4]) np.testing.assert_equal(rowsize_new, [3, 4]) def test_prune_empty(self): - x = [] - rowsize = [] + x: list[int] = [] + rowsize: list[int] = [] minimum = 3 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + data: ArrayTypes + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): with self.assertRaises(IndexError): x_new, rowsize_new = prune(data, rowsize, minimum) @@ -285,7 +303,10 @@ def test_print_incompatible_rowsize(self): rowsize = [3, 3] minimum = 3 - for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]: + data: ArrayTypes + for data in list[ArrayTypes]( + [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)] + ): with self.assertRaises(ValueError): x_new, rowsize_new = prune(data, rowsize, minimum) @@ -304,9 +325,7 @@ def test_segment(self): def test_segment_zero_tolerance(self): x = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4] tol = 0 - self.assertIsNone( - np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4])) - ) + np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4])) def test_segment_negative_tolerance(self): x = [0, 1, 1, 1, 2, 0, 3, 3, 3, 4] @@ -316,7 +335,7 @@ def test_segment_negative_tolerance(self): def test_segment_rowsize(self): x = [0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4] tol = 0.5 - rowsize = [6, 5] + rowsize = np.array([6, 5]) 
diff --git a/tests/adapters/gdp/gdp1h_integ_tests.py b/tests/adapters/gdp/gdp1h_integ_tests.py
index 38053fc3..265d00d9 100644
--- a/tests/adapters/gdp/gdp1h_integ_tests.py
+++ b/tests/adapters/gdp/gdp1h_integ_tests.py
@@ -47,7 +47,5 @@ def test_load_subset_and_create_aggregate(self):
 
     @classmethod
     def tearDownClass(cls):
-        [
+        for dir in [gdp1h.GDP_TMP_PATH, gdp1h.GDP_TMP_PATH_EXPERIMENTAL]:
             shutil.rmtree(dir)
-            for dir in [gdp1h.GDP_TMP_PATH, gdp1h.GDP_TMP_PATH_EXPERIMENTAL]
-        ]
diff --git a/tests/adapters/gdp/gdp6h_integ_tests.py b/tests/adapters/gdp/gdp6h_integ_tests.py
index fd98ad26..f8a92fe8 100644
--- a/tests/adapters/gdp/gdp6h_integ_tests.py
+++ b/tests/adapters/gdp/gdp6h_integ_tests.py
@@ -24,4 +24,4 @@ def test_load_subset_and_create_aggregate(self):
 
     @classmethod
     def tearDownClass(cls):
-        [shutil.rmtree(dir) for dir in [gdp6h.GDP_TMP_PATH]]
+        shutil.rmtree(gdp6h.GDP_TMP_PATH)
diff --git a/tests/adapters/gdp/source_integ_tests.py b/tests/adapters/gdp/source_integ_tests.py
index 7d91bb2d..6670dc4f 100644
--- a/tests/adapters/gdp/source_integ_tests.py
+++ b/tests/adapters/gdp/source_integ_tests.py
@@ -44,4 +44,4 @@ def test_load_and_create_aggregate(self):
 
     @classmethod
     def tearDownClass(cls):
-        [shutil.rmtree(dir) for dir in [gdp_source._TMP_PATH]]
+        shutil.rmtree(gdp_source._TMP_PATH)
diff --git a/tests/adapters/hurdat2_integ_tests.py b/tests/adapters/hurdat2_integ_tests.py
index bb826001..f2a809b0 100644
--- a/tests/adapters/hurdat2_integ_tests.py
+++ b/tests/adapters/hurdat2_integ_tests.py
@@ -39,4 +39,4 @@ def test_conversion(self):
 
     @classmethod
     def tearDownClass(cls):
-        [shutil.rmtree(dir) for dir in [hurdat2._DEFAULT_FILE_PATH]]
+        shutil.rmtree(hurdat2._DEFAULT_FILE_PATH)
diff --git a/tests/adapters/utils.py b/tests/adapters/utils.py
index 7772ed2b..7eb3f4b5 100644
--- a/tests/adapters/utils.py
+++ b/tests/adapters/utils.py
@@ -3,13 +3,14 @@
 
 
 class MultiPatcher:
-    _patches: Sequence[_patch]
+    _patches: Sequence[_patch]  # type: ignore
 
-    def __init__(self, patches: Sequence[_patch]):
+    def __init__(self, patches: Sequence[_patch]):  # type: ignore
         self._patches = patches
 
     def __enter__(self) -> Sequence[Mock]:
         return [p.start() for p in self._patches]
 
     def __exit__(self, *_):
-        [p.stop() for p in self._patches]
+        for p in self._patches:
+            p.stop()
diff --git a/tests/adapters/utils_tests.py b/tests/adapters/utils_tests.py
index 41da65f6..c417c057 100644
--- a/tests/adapters/utils_tests.py
+++ b/tests/adapters/utils_tests.py
@@ -28,7 +28,6 @@ def setUp(self) -> None:
         self.requests_mock = Mock()
         self.requests_mock.head = Mock(return_value=self.head_response_mock)
         self.requests_mock.get = Mock(return_value=self.get_response_mock)
-
         self.open_mock = mock_open()
         self.bar_mock = Mock()
 
@@ -57,9 +56,7 @@ def test_forgo_download_no_update(self):
                 ),
             ]
         ) as _:
-            utils._download_with_progress(
-                "some.url.com", "./some/path/existing-file.nc", 0, False
-            )
+            utils._download_with_progress("some.url.com", "/some/path/here", 0, False)
             self.requests_mock.get.assert_not_called()
 
     def test_download_new_update(self):
@@ -129,7 +126,7 @@ def test_progress_mechanism_enabled_files(self):
         """
         mocked_futures = [self.gen_future_mock() for _ in range(0, 21)]
 
-        download_requests = [("src0", "dst", None) for _ in range(0, 21)]
+        download_requests = [("src0", "dst") for _ in range(0, 21)]
 
         tpe_mock = Mock()
         tpe_mock.__enter__ = Mock(return_value=tpe_mock)
diff --git a/tests/datasets_tests.py b/tests/datasets_tests.py
index 6d363989..b5c721a6 100644
--- a/tests/datasets_tests.py
+++ b/tests/datasets_tests.py
@@ -1,4 +1,7 @@
+import typing
+
 import numpy as np
+from numpy.typing import NDArray
 
 import tests.utils as testutils
 from clouddrift import datasets
@@ -36,8 +39,12 @@ def test_glad_subset_and_apply_ragged_work(self):
             row_dim_name="traj",
         )
         self.assertTrue(ds_sub)
-        mean_lon = apply_ragged(np.mean, [ds_sub.longitude], ds_sub.rowsize)
-        self.assertTrue(mean_lon.size == 2)
+        mean_lon = apply_ragged(self._mean, [ds_sub.longitude], ds_sub.rowsize)
+        self.assertTrue(len(mean_lon) == 2)
+
+    # For static typing purposes
+    def _mean(self, x: NDArray[typing.Any]) -> NDArray[typing.Any]:
+        return np.mean(x)
 
     def test_spotters_opens(self):
         with datasets.spotters() as ds:
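The `_mean` indirection above exists because `np.mean`'s stubs return a scalar/`Any`, which does not satisfy `apply_ragged`'s array-bound TypeVar; a thin typed wrapper pins an `NDArray -> NDArray` signature without changing runtime behaviour. The same pattern works for other reducers (hypothetical helper, assuming the same stub behaviour):

```python
import typing
import numpy as np
from numpy.typing import NDArray

def _std(x: NDArray[typing.Any]) -> NDArray[typing.Any]:
    # Mirrors _mean above: the annotation widens the stubbed scalar return
    # so the callable satisfies apply_ragged's _ArrayOutput bound.
    return np.std(x)
```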
diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py
index 94db576f..ea8f7864 100644
--- a/tests/ragged_tests.py
+++ b/tests/ragged_tests.py
@@ -1,7 +1,7 @@
+import typing
 import unittest
 from concurrent import futures
 from datetime import datetime, timedelta
-from typing import Any
 
 import numpy as np
 import pandas as pd
@@ -20,6 +20,7 @@
     unpack,
 )
 from clouddrift.raggedarray import RaggedArray
+from clouddrift.typing import ArrayTypes, ToleranceTypes
 
 if __name__ == "__main__":
     unittest.main()
@@ -40,9 +41,12 @@ def sample_ragged_array() -> RaggedArray:
         "title": "test trajectories",
         "history": "version xyz",
     }
-    coords: dict[str, list] = {"id": drifter_id, "time": t}
     metadata = {"rowsize": rowsize}
-    data: dict[str, list] = {"test": test, "lat": latitude, "lon": longitude}
+    data: dict[str, list[list[int]] | list[list[bool]]] = {
+        "test": test,
+        "lat": latitude,
+        "lon": longitude,
+    }
 
     # append xr.Dataset to a list
     list_ds = []
@@ -50,16 +54,16 @@ def sample_ragged_array() -> RaggedArray:
         xr_coords = {}
         xr_coords["id"] = (
             ["rows"],
-            [coords["id"][i]],
+            [drifter_id[i]],
             {"long_name": "variable id", "units": "-"},
         )
         xr_coords["time"] = (
             ["obs"],
-            coords["time"][i],
+            t[i],
             {"long_name": "variable time", "units": "-"},
         )
 
-        xr_data: dict[str, Any] = {}
+        xr_data: dict[str, typing.Any] = {}
         for var in metadata.keys():
             xr_data[var] = (
                 ["traj"],
@@ -207,11 +211,13 @@ def test_chunk_array_like(self):
 
 class prune_tests(unittest.TestCase):
     def test_prune(self):
-        x = [1, 2, 3, 1, 2, 1, 2, 3, 4]
+        x: list[int] = [1, 2, 3, 1, 2, 1, 2, 3, 4]
         rowsize = [3, 2, 4]
         minimum = 3
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             x_new, rowsize_new = prune(data, rowsize, minimum)
             self.assertTrue(isinstance(x_new, np.ndarray))
             self.assertTrue(isinstance(rowsize_new, np.ndarray))
@@ -223,7 +229,10 @@ def test_prune_all_longer(self):
         rowsize = [3, 2, 4]
         minimum = 1
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        data: ArrayTypes
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             x_new, rowsize_new = prune(data, rowsize, minimum)
             np.testing.assert_equal(x_new, data)
             np.testing.assert_equal(rowsize_new, rowsize)
@@ -233,7 +242,10 @@ def test_prune_all_smaller(self):
         rowsize = [3, 2, 4]
         minimum = 5
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        data: ArrayTypes
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             x_new, rowsize_new = prune(data, rowsize, minimum)
             np.testing.assert_equal(x_new, np.array([]))
             np.testing.assert_equal(rowsize_new, np.array([]))
@@ -262,21 +274,27 @@ def test_prune_dates(self):
         np.testing.assert_equal(rowsize_new, [5, 8])
 
     def test_prune_keep_nan(self):
-        x = [1, 2, np.nan, 1, 2, 1, 2, np.nan, 4]
+        x: list[int | float | np.float64] = [1, 2, np.nan, 1, 2, 1, 2, np.nan, 4]
         rowsize = [3, 2, 4]
         minimum = 3
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        data: ArrayTypes
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             x_new, rowsize_new = prune(data, rowsize, minimum)
             np.testing.assert_equal(x_new, [1, 2, np.nan, 1, 2, np.nan, 4])
             np.testing.assert_equal(rowsize_new, [3, 4])
 
     def test_prune_empty(self):
-        x = []
-        rowsize = []
+        x: list[int] = []
+        rowsize: list[int] = []
         minimum = 3
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        data: ArrayTypes
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             with self.assertRaises(IndexError):
                 x_new, rowsize_new = prune(data, rowsize, minimum)
 
@@ -285,7 +303,10 @@ def test_print_incompatible_rowsize(self):
         rowsize = [3, 3]
         minimum = 3
 
-        for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
+        data: ArrayTypes
+        for data in list[ArrayTypes](
+            [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
+        ):
             with self.assertRaises(ValueError):
                 x_new, rowsize_new = prune(data, rowsize, minimum)
 
@@ -304,9 +325,7 @@ def test_segment(self):
     def test_segment_zero_tolerance(self):
         x = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
         tol = 0
-        self.assertIsNone(
-            np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4]))
-        )
+        np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4]))
 
     def test_segment_negative_tolerance(self):
         x = [0, 1, 1, 1, 2, 0, 3, 3, 3, 4]
@@ -316,7 +335,7 @@ def test_segment_negative_tolerance(self):
     def test_segment_rowsize(self):
         x = [0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4]
         tol = 0.5
-        rowsize = [6, 5]
+        rowsize = np.array([6, 5])
         segment_sizes = segment(x, tol, rowsize)
         self.assertTrue(isinstance(segment_sizes, np.ndarray))
         self.assertTrue(np.all(segment_sizes == np.array([1, 3, 2, 4, 1])))
@@ -327,9 +346,9 @@ def test_segment_positive_and_negative_tolerance(self):
         self.assertTrue(np.all(segment_sizes == np.array([2, 2, 2, 2])))
 
     def test_segment_rowsize_raises(self):
-        x = [0, 1, 2, 3]
+        x = np.array([0, 1, 2, 3])
         tol = 0.5
-        rowsize = [1, 2]  # rowsize is too short
+        rowsize = np.array([1, 2])  # rowsize is too short
         with self.assertRaises(ValueError):
             segment(x, tol, rowsize)
 
@@ -341,10 +360,11 @@ def test_segments_datetime(self):
             datetime(2023, 2, 1),
             datetime(2023, 2, 2),
         ]
-        for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
-            self.assertIsNone(
-                np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
-            )
+        tol: ToleranceTypes
+        for tol in list[ToleranceTypes](
+            [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
+        ):
+            np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
 
     def test_segments_numpy(self):
         x = np.array(
@@ -356,17 +376,19 @@ def test_segments_numpy(self):
                 np.datetime64("2023-02-02"),
             ]
         )
-        for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
-            self.assertIsNone(
-                np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
-            )
+        for tol in list[ToleranceTypes](
+            [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
+        ):
+            np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
 
     def test_segments_pandas(self):
-        x = pd.to_datetime(["1/1/2023", "1/2/2023", "1/3/2023", "2/1/2023", "2/2/2023"])
-        for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
-            self.assertIsNone(
-                np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
-            )
+        x: pd.Series[pd.Timestamp] = pd.to_datetime(
+            pd.Series(["1/1/2023", "1/2/2023", "1/3/2023", "2/1/2023", "2/2/2023"])
+        )
+        for tol in list[ToleranceTypes](
+            [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
+        ):
+            np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
 
 
 class ragged_to_regular_tests(unittest.TestCase):
@@ -570,10 +592,7 @@ def test_bad_rowsize_raises(self):
         with self.assertRaises(ValueError):
             for use_threads in [True, False]:
                 apply_ragged(
-                    lambda x: x**2,
-                    np.array([1, 2, 3, 4]),
-                    [2],
-                    use_threads=use_threads,
+                    lambda x: x**2, np.array([1, 2, 3, 4]), [2], use_threads=use_threads
                 )
diff --git a/tests/raggedarray_tests.py b/tests/raggedarray_tests.py
index b8fbb8ee..25c26300 100644
--- a/tests/raggedarray_tests.py
+++ b/tests/raggedarray_tests.py
@@ -7,6 +7,7 @@
 import xarray as xr
 
 from clouddrift import RaggedArray
+from clouddrift.raggedarray import DimNames
 
 NETCDF_ARCHIVE = "test_archive.nc"
 PARQUET_ARCHIVE = "test_archive.parquet"
@@ -16,6 +17,19 @@
 
 
 class raggedarray_tests(TestCase):
+    ra: RaggedArray
+    drifter_id: list[int]
+    rowsize: list[int]
+    nb_traj: int
+    nb_obs: int
+    attrs_global: dict[str, str]
+    variables_coords: list[tuple[str, str]]
+    name_coords: list[str]
+    name_meta: list[str]
+    name_data: list[str]
+    name_dims: dict[str, DimNames]
+    coord_dims: dict[str, str]
+
     @classmethod
     def setUpClass(self):
         """
@@ -43,7 +57,7 @@ def setUpClass(self):
 
         xr_coords["time"] = (
             ["obs"],
-            np.ones(self.rowsize[i], dtype="int") * self.drifter_id[i],
+            (np.ones(self.rowsize[i], dtype="int") * self.drifter_id[i]).tolist(),
             {"long_name": "variable time", "units": "-"},
         )
 
@@ -55,7 +69,7 @@ def setUpClass(self):
         )
         xr_data["temp"] = (
             ["obs"],
-            np.random.rand(self.rowsize[i]),
+            np.random.rand(self.rowsize[i]).tolist(),
             {"long_name": "variable temp", "units": "-"},
         )
diff --git a/tests/signal_tests.py b/tests/signal_tests.py
index f8eed6bb..ecadc5f9 100644
--- a/tests/signal_tests.py
+++ b/tests/signal_tests.py
@@ -29,12 +29,18 @@ def test_imag(self):
     def test_real_odd(self):
         x = np.random.rand(99)
         z = analytic_signal(x)
-        self.assertTrue(np.allclose(x, z.real))
+        if isinstance(z, np.ndarray):
+            self.assertTrue(np.allclose(x, z.real))
+        else:
+            raise ValueError("Expected only the analytic signal")
 
     def test_real_even(self):
         x = np.random.rand(100)
         z = analytic_signal(x)
-        self.assertTrue(np.allclose(x, z.real))
+        if isinstance(z, np.ndarray):
+            self.assertTrue(np.allclose(x, z.real))
+        else:
+            raise ValueError("Expected only the analytic signal")
 
     def test_imag_odd(self):
         z = np.random.rand(99) + 1j * np.random.rand(99)
@@ -51,21 +57,29 @@ def test_boundary(self):
         z1 = analytic_signal(x, boundary="mirror")
         z2 = analytic_signal(x, boundary="zeros")
         z3 = analytic_signal(x, boundary="periodic")
-        self.assertTrue(np.allclose(x, z1.real))
-        self.assertTrue(np.allclose(x, z2.real))
-        self.assertTrue(np.allclose(x, z3.real))
+        for z in [z1, z2, z3]:
+            if isinstance(z, np.ndarray):
+                self.assertTrue(np.allclose(x, z.real))
+            else:
+                raise ValueError("Expected only the analytic signal")
 
     def test_ndarray(self):
         x = np.random.random((9, 11, 13))
         for n in range(3):
             z = analytic_signal(x, time_axis=n)
-            self.assertTrue(np.allclose(x, z.real))
+            if isinstance(z, np.ndarray):
+                self.assertTrue(np.allclose(x, z.real))
+            else:
+                raise ValueError("Expected only the analytic signal")
 
     def test_xarray(self):
         x = xr.DataArray(data=np.random.random((9, 11, 13)))
         for n in range(3):
             z = analytic_signal(x, time_axis=n)
-            self.assertTrue(np.allclose(x, z.real))
+            if isinstance(z, np.ndarray):
+                self.assertTrue(np.allclose(x, z.real))
+            else:
+                raise ValueError("Expected only the analytic signal")
 
 
 class cartesian_to_rotary_tests(unittest.TestCase):
@@ -99,12 +113,15 @@ def test_invert_cartesian_to_rotary(self):
         v = np.random.rand(99)
         ua = analytic_signal(u)
         va = analytic_signal(v)
-        wp, wn = cartesian_to_rotary(ua, va)
-        ua_, va_ = rotary_to_cartesian(wp, wn)
-        self.assertTrue(np.allclose(ua, ua_))
-        self.assertTrue(np.allclose(va, va_))
-        self.assertTrue(np.allclose(u, np.real(ua_)))
-        self.assertTrue(np.allclose(v, np.real(va_)))
+        if isinstance(ua, np.ndarray) and isinstance(va, np.ndarray):
+            wp, wn = cartesian_to_rotary(ua, va)
+            ua_, va_ = rotary_to_cartesian(wp, wn)
+            self.assertTrue(np.allclose(ua, ua_))
+            self.assertTrue(np.allclose(va, va_))
+            self.assertTrue(np.allclose(u, np.real(ua_)))
+            self.assertTrue(np.allclose(v, np.real(va_)))
+        else:
+            raise ValueError("ua or va are tuples when expecting ndarray")
 
 
 class ellipse_parameters_tests(unittest.TestCase):
diff --git a/tests/sphere_tests.py b/tests/sphere_tests.py
index 3f709d21..d2bfa7bc 100644
--- a/tests/sphere_tests.py
+++ b/tests/sphere_tests.py
@@ -57,11 +57,9 @@ class recast_longitude_tests(unittest.TestCase):
     def test_same_shape(self):
         self.assertTrue(recast_lon(np.array([200])).shape == np.zeros(1).shape)
         self.assertTrue(recast_lon(np.array([200, 200])).shape == np.zeros(2).shape)
-        self.assertIsNone(
-            np.testing.assert_equal(
-                recast_lon(np.array([[200.5, -200.5], [200.5, -200.5]])).shape,
-                np.zeros((2, 2)).shape,
-            )
+        np.testing.assert_equal(
+            recast_lon(np.array([[200.5, -200.5], [200.5, -200.5]])).shape,
+            np.zeros((2, 2)).shape,
         )
 
     def test_different_lon0(self):
@@ -259,10 +257,6 @@ def test_with_origin(self):
         self.assertTrue(np.allclose(lon, np.array([lon_origin, lon_origin])))
         self.assertTrue(np.allclose(lat, np.array([lat_origin, lat_origin + 1])))
 
-    def test_scalar_raises_error(self):
-        with self.assertRaises(Exception):
-            plane_to_sphere(0, 0)
-
 
 class sphere_to_plane_tests(unittest.TestCase):
     def test_simple(self):
@@ -331,10 +325,6 @@ def test_with_origin(self):
             )
         )
 
-    def test_scalar_raises_error(self):
-        with self.assertRaises(Exception):
-            sphere_to_plane(0, 0)
-
 
 class sphere_to_plane_roundtrip(unittest.TestCase):
     def test_roundtrip(self):
@@ -394,7 +384,8 @@ def test_cartesian_to_spherical_invert(self):
 class cartesian_to_tangentplane_tests(unittest.TestCase):
     def test_cartesian_to_tangentplane_values(self):
         up, vp = cartesian_to_tangentplane(1.0, 1.0, 1.0, 0.0, 0.0)
-        self.assertTrue(np.allclose((up, vp), (1.0, 1.0)))
+
+        self.assertTrue(np.allclose([up, vp], [1.0, 1.0]))
         up, vp = cartesian_to_tangentplane(1.0, 1.0, 1.0, 90.0, 0.0)
         self.assertTrue(np.allclose((up, vp), (-1.0, 1.0)))
         up, vp = cartesian_to_tangentplane(1.0, 1.0, 1.0, 180.0, 0.0)
diff --git a/tests/wavelet_tests.py b/tests/wavelet_tests.py
index b5659564..52b222a2 100644
--- a/tests/wavelet_tests.py
+++ b/tests/wavelet_tests.py
@@ -22,7 +22,7 @@ def test_morse_wavelet_transform_real(self):
         length = 1023
         radian_frequency = 2 * np.pi / np.logspace(np.log10(10), np.log10(100), 50)
         x = np.random.random(length)
-        wtx = morse_wavelet_transform(x, 3, 10, radian_frequency)
+        wtx = morse_wavelet_transform(x, 3, 10, radian_frequency, False)
         wavelet, _ = morse_wavelet(length, 3, 10, radian_frequency)
         wtx2 = wavelet_transform(x, wavelet)
         self.assertTrue(np.allclose(wtx, wtx2))
@@ -82,7 +82,7 @@ def test_morse_wavelet_transform_cos(self):
         f = 0.2
         t = np.arange(0, 1000)
         x = np.cos(2 * np.pi * t * f)
-        wtx = morse_wavelet_transform(x, 3, 10, 2 * np.pi * np.array([f]))
+        wtx = morse_wavelet_transform(x, 3, 10, 2 * np.pi * np.array([f]), False)
         self.assertTrue(np.isclose(np.var(x), 0.5 * np.var(wtx), atol=1e-2))
 
     def test_morse_wavelet_transform_exp(self):
@@ -284,11 +284,16 @@ def test_morse_freq_beta_zero(self):
 class morse_logspace_freq_tests(unittest.TestCase):
     def test_morse_logspace_freq_high(self):
         # here we are not testing the morse_logspace_freq function
-        gamma = np.array([3])
-        beta = np.array([4])
+        gamma = 3
+        beta = 4
         eta = 0.1
-        fhigh = _morsehigh(gamma, beta, eta)
-        _, waveletfft = morse_wavelet(10000, gamma, beta, fhigh)
+        fhigh = _morsehigh(np.array([gamma]), np.array([beta]), eta)
+
+        if isinstance(fhigh, np.ndarray):
+            _, waveletfft = morse_wavelet(10000, gamma, beta, fhigh)
+        else:
+            _, waveletfft = morse_wavelet(10000, gamma, beta, np.array(fhigh))
+
         self.assertTrue(
             np.isclose(
                 np.abs(0.5 * waveletfft[0, 0, int(10000 / 2) - 1]), eta, atol=1e-3