Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 fix typing issues from checking untyped defs, fixes #509 #510

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions clouddrift/adapters/gdp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import os
import tempfile
import typing

import numpy as np
import pandas as pd
import xarray as xr
from numpy.typing import NDArray

from clouddrift.adapters.utils import download_with_progress
from clouddrift.raggedarray import DimNames
Expand Down Expand Up @@ -269,7 +271,9 @@ def str_to_float(value: str, default: float = np.nan) -> float:
return default


def cut_str(value: str, max_length: int) -> np.chararray:
def cut_str(
value: str, max_length: int
) -> np.chararray[typing.Any, np.dtype[np.bytes_]]:
"""Cut a string to a specific length and return it as a numpy chararray.

Parameters
Expand All @@ -289,7 +293,7 @@ def cut_str(value: str, max_length: int) -> np.chararray:
return charar


def drogue_presence(lost_time, time) -> np.ndarray:
def drogue_presence(lost_time, time) -> NDArray[typing.Any]:
"""Create drogue status from the drogue lost time and the trajectory time.

Parameters
Expand Down
29 changes: 17 additions & 12 deletions clouddrift/adapters/gdp/gdpsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import logging
import os
import tempfile
import typing
import warnings
from collections import defaultdict
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from typing import Callable

import numpy as np
import numpy.typing as np_typing
import pandas as pd
import xarray as xr
from tqdm.asyncio import tqdm
Expand Down Expand Up @@ -56,7 +57,7 @@
"death_code",
]

_VARS_FILL_MAP: dict = {
_VARS_FILL_MAP: dict[str, int | str | np.datetime64] = {
"wmo_number": -999,
"program_number": -999,
"buoys_type": "N/A",
Expand All @@ -70,7 +71,7 @@
"death_code": -999,
}

_VAR_DTYPES: dict = {
_VAR_DTYPES: dict[str, type | np.dtype[np.datetime64]] = {
"rowsize": np.int64,
"wmo_number": np.int64,
"program_number": np.int64,
Expand Down Expand Up @@ -126,7 +127,7 @@
}


VARS_ATTRS: dict = {
VARS_ATTRS: dict[str, dict[str, str]] = {
"id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"},
"rowsize": {
"long_name": "Number of observations per trajectory",
Expand Down Expand Up @@ -334,7 +335,9 @@
return dataset


def _apply_remove(df: pd.DataFrame, filters: list[Callable]) -> pd.DataFrame:
def _apply_remove(
df: pd.DataFrame, filters: list[typing.Callable[..., typing.Any]]
) -> pd.DataFrame:
temp_df = df
for filter_ in filters:
mask = filter_(temp_df)
Expand All @@ -344,7 +347,7 @@

def _apply_transform(
df: pd.DataFrame,
transforms: dict[str, tuple[list[str], Callable]],
transforms: dict[str, tuple[list[str], typing.Callable[..., typing.Any]]],
) -> pd.DataFrame:
tmp_df = df
for output_col in transforms.keys():
Expand All @@ -359,8 +362,10 @@


def _parse_datetime_with_day_ratio(
month_series: np.ndarray, day_series: np.ndarray, year_series: np.ndarray
) -> np.ndarray:
month_series: np_typing.NDArray[np.float32],
day_series: np_typing.NDArray[np.float32],
year_series: np_typing.NDArray[np.float32],
) -> np_typing.NDArray[np.datetime64]:
values = list()
for month, day_with_ratio, year in zip(month_series, day_series, year_series):
day = day_with_ratio // 1
Expand Down Expand Up @@ -479,7 +484,7 @@
)

sort_coord = traj_dataset.coords["obs_index"]
vals: np.ndarray = sort_coord.data
vals: np_typing.NDArray[np.int64] = sort_coord.data

Check warning on line 487 in clouddrift/adapters/gdp/gdpsource.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp/gdpsource.py#L487

Added line #L487 was not covered by tests
sort_coord_dim = sort_coord.dims[-1]
sort_key = vals.argsort()

Expand Down Expand Up @@ -531,8 +536,8 @@
chunksize=chunk_size,
)

joblist = list[Future]()
jobmap = dict[Future, pd.DataFrame]()
joblist = list[Future[dict[int, xr.Dataset]]]()
jobmap = dict[Future[dict[int, xr.Dataset]], pd.DataFrame]()
for idx, chunk in enumerate(file_chunks):
if max_chunks is not None and idx >= max_chunks:
break
Expand Down Expand Up @@ -568,7 +573,7 @@
drifter_chunked_datasets[id_].append(drifter_ds)
bar.update()

combine_jobmap = dict[Future, int]()
combine_jobmap = dict[Future[xr.Dataset], int]()
for id_ in drifter_chunked_datasets.keys():
datasets = drifter_chunked_datasets[id_]

Expand Down
4 changes: 3 additions & 1 deletion clouddrift/adapters/hurdat2.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ def to_raggedarray(
preprocess_func=lambda idx: track_data[idx].to_xarray_dataset(),
attrs_global=TrackData.global_attrs,
attrs_variables={
field.name: field.metadata
field.name: dict(
field.metadata
) # dict() builds a plain dict from field.metadata (a runtime conversion, not just a cast) so static type analysis accepts it
for field in fields(HeaderLine) + fields(DataLine)
},
)
Expand Down
21 changes: 11 additions & 10 deletions clouddrift/adapters/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import concurrent.futures
import logging
import os
import typing
import urllib
from datetime import datetime
from io import BufferedIOBase, BufferedWriter
from typing import Callable, Sequence

import requests
from tenacity import (
RetryCallState,
WrappedFn,
retry,
retry_if_exception,
stop_after_attempt,
Expand All @@ -30,7 +29,10 @@
_CHUNK_SIZE = 1_048_576 # 1MiB
_logger = logging.getLogger(__name__)

standard_retry_protocol: Callable[[WrappedFn], WrappedFn] = retry(
_Func = typing.Callable[..., typing.Any]
_Wrapper = typing.Callable[[_Func], _Func]

standard_retry_protocol: _Wrapper = retry(
retry=retry_if_exception(
lambda ex: isinstance(
ex,
Expand All @@ -51,19 +53,19 @@


def download_with_progress(
download_map: Sequence[
download_map: typing.Sequence[
tuple[str, BufferedIOBase | str] | tuple[str, BufferedIOBase | str, float]
],
show_list_progress: bool | None = None,
desc: str = "Downloading files",
custom_retry_protocol: Callable[[WrappedFn], WrappedFn] | None = None,
custom_retry_protocol: _Wrapper | None = None,
):
if show_list_progress is None:
show_list_progress = len(download_map) > 20
if custom_retry_protocol is None:
retry_protocol = standard_retry_protocol
else:
retry_protocol = custom_retry_protocol # type: ignore
retry_protocol = custom_retry_protocol

Check warning on line 68 in clouddrift/adapters/utils.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/utils.py#L68

Added line #L68 was not covered by tests

executor = concurrent.futures.ThreadPoolExecutor()
futures: dict[
Expand Down Expand Up @@ -156,10 +158,10 @@
)

_logger.debug(f"Downloading from {url} to {output}...")
bar = None

with requests.get(url, timeout=5, stream=True) as response:
buffer: BufferedWriter | BufferedIOBase | None = None
bar = None

try:
if isinstance(output, (str,)):
buffer = open(output, "wb")
Expand All @@ -179,15 +181,14 @@
nrows=2,
disable=_DISABLE_SHOW_PROGRESS,
)

for chunk in response.iter_content(_CHUNK_SIZE):
if not chunk:
break
buffer.write(chunk)
if bar is not None:
bar.update(len(chunk))
finally:
if response is not None:
response.close()
if bar is not None:
bar.close()
if buffer is not None and isinstance(output, (str,)):
Expand Down
11 changes: 6 additions & 5 deletions clouddrift/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""

import numpy as np
import pandas as pd
import numpy.typing as np_typing
import xarray as xr

import clouddrift.typing as cd_typing
from clouddrift.sphere import (
EARTH_RADIUS_METERS,
bearing,
Expand All @@ -21,9 +22,9 @@


def kinetic_energy(
u: float | list | np.ndarray | xr.DataArray | pd.Series,
v: float | list | np.ndarray | xr.DataArray | pd.Series | None = None,
) -> float | np.ndarray | xr.DataArray:
u: float | cd_typing.ArrayTypes,
v: float | cd_typing.ArrayTypes | None = None,
) -> float | np_typing.NDArray[np.float64] | xr.DataArray:
"""Compute kinetic energy from zonal and meridional velocities.

Parameters
Expand Down Expand Up @@ -531,7 +532,7 @@ def velocity_from_position(
coord_system: str = "spherical",
difference_scheme: str = "forward",
time_axis: int = -1,
) -> tuple[xr.DataArray, xr.DataArray]:
) -> tuple[np.ndarray, np.ndarray]:
"""Compute velocity from arrays of positions and time.

x and y can be provided as longitude and latitude in degrees if
Expand Down
36 changes: 22 additions & 14 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,34 @@
Transformational and inquiry functions for ragged arrays.
"""

import typing
import warnings
from collections.abc import Callable, Iterable
from concurrent import futures
from datetime import timedelta

import numpy as np
import numpy.typing as np_typing
import pandas as pd
import xarray as xr

import clouddrift.typing as cd_typing

_ArrayOutput = typing.TypeVar(
"_ArrayOutput", bound=tuple[np.ndarray, np.ndarray] | np.ndarray
)


def apply_ragged(
func: callable,
arrays: list[np.ndarray | xr.DataArray] | np.ndarray | xr.DataArray,
rowsize: list[int] | np.ndarray[int] | xr.DataArray,
*args: tuple,
func: Callable[..., _ArrayOutput],
arrays: cd_typing.ArrayTypes,
rowsize: cd_typing.ArrayTypes,
*args: typing.Any,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep things consistent, it might help to also update the docstrings here so they match the new type annotations.

rows: int | Iterable[int] = None,
axis: int = 0,
executor: futures.Executor = futures.ThreadPoolExecutor(max_workers=None),
**kwargs: dict,
) -> tuple[np.ndarray] | np.ndarray:
**kwargs: typing.Any,
) -> _ArrayOutput:
"""Apply a function to a ragged array.

The function ``func`` will be applied to each contiguous row of ``arrays`` as
Expand Down Expand Up @@ -423,7 +431,7 @@ def regular_to_ragged(
return array[valid], np.sum(valid, axis=1)


def rowsize_to_index(rowsize: list | np.ndarray | xr.DataArray) -> np.ndarray:
def rowsize_to_index(rowsize: cd_typing.ArrayTypes) -> np.ndarray:
"""Convert a list of row sizes to a list of indices.

This function is typically used to obtain the indices of data rows organized
Expand All @@ -450,10 +458,10 @@ def rowsize_to_index(rowsize: list | np.ndarray | xr.DataArray) -> np.ndarray:


def segment(
x: np.ndarray,
x: cd_typing.ArrayTypes,
tolerance: float | np.timedelta64 | timedelta | pd.Timedelta,
rowsize: np.ndarray[int] = None,
) -> np.ndarray[int]:
rowsize: np_typing.NDArray[np.int64] | None = None,
) -> np_typing.NDArray[np.int64]:
"""Divide an array into segments based on a tolerance value.

Parameters
Expand Down Expand Up @@ -787,11 +795,11 @@ def subset(


def unpack(
ragged_array: np.ndarray,
rowsize: np.ndarray[int],
rows: int | Iterable[int] = None,
ragged_array: cd_typing.ArrayTypes,
rowsize: cd_typing.ArrayTypes,
rows: int | np.int64 | Iterable[int] | None = None,
axis: int = 0,
) -> list[np.ndarray]:
) -> list[np_typing.NDArray[typing.Any]]:
"""Unpack a ragged array into a list of regular arrays.

Unpacking a ``np.ndarray`` ragged array is about 2 orders of magnitude
Expand Down
Loading