diff --git a/requirements.txt b/requirements.txt index 65c12d9..d7ec62d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ zarr>=2.0.0 orjson>=3.0.0 pydantic>=2.0.0 more-itertools>=10.0.0 -loguru>=0.7.0 \ No newline at end of file +loguru>=0.7.0 +numbagg>=0.8.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 3855a6c..cfe01c7 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='TensorDB', - version='0.30.1', + version='0.30.3', description='Database based in a file system storage combined with Xarray and Zarr', author='Joseph Nowak', author_email='josephgonowak97@gmail.com', diff --git a/tensordb/algorithms.py b/tensordb/algorithms.py index 997e8e2..fc597a9 100644 --- a/tensordb/algorithms.py +++ b/tensordb/algorithms.py @@ -5,6 +5,7 @@ import dask.array as da import numpy as np import xarray as xr +import numbagg as nba from dask.distributed import Client from scipy.stats import rankdata @@ -731,3 +732,68 @@ def structured_inequality(x, top): ) return bitmask.expand_dims({tie_breaker_dim: new_data.coords[tie_breaker_dim]}) + + @classmethod + def rolling_overlap( + cls, + new_data: Union[xr.DataArray, xr.Dataset], + func: Callable, + dim: str, + window: int, + window_margin: int, + min_periods: int = None, + apply_ffill: bool = True, + validate_window_size: bool = True + ): + assert window_margin >= window + + if validate_window_size: + window_margin = min(window_margin, new_data.sizes[dim]) + + assert window > 0 + + if isinstance(new_data, xr.Dataset): + return new_data.map( + cls.rolling_overlap, + func=func, + window=window, + dim=dim, + window_margin=window_margin, + min_periods=min_periods, + ) + + min_periods = window if min_periods is None else min_periods + axis = new_data.dims.index(dim) + data = new_data.data + + def _apply_on_valid(x): + x = x.copy() + bitmask = ~np.isnan(x) + filter_x = x[bitmask] + + min_window = window + if validate_window_size: + min_window = min(window, len(filter_x)) + if min_window == 0: + return x + + func(filter_x, window=min_window, min_count=min_periods, out=filter_x) + + x[bitmask] = filter_x + if apply_ffill: + nba.ffill(arr=x, limit=len(x), out=x) + return x + + def apply_on_valid(x): + return np.apply_along_axis(func1d=_apply_on_valid, axis=axis, arr=x) + + data = data.map_overlap( + apply_on_valid, + depth={axis: window_margin}, + boundary=None, + meta=data, + trim=True, + ) + + data = xr.DataArray(data, dims=new_data.dims, coords=new_data.coords) + return data diff --git a/tensordb/tests/test_algorithms.py b/tensordb/tests/test_algorithms.py index 79a6b49..9d564f8 100644 --- a/tensordb/tests/test_algorithms.py +++ b/tensordb/tests/test_algorithms.py @@ -2,6 +2,7 @@ import pandas as pd import pytest import xarray as xr +import numbagg as nba from tensordb.algorithms import Algorithms @@ -445,5 +446,49 @@ def test_cumulative_on_sort(dim, ascending, func): assert result.equals(expected) +@pytest.mark.parametrize("window", list(range(1, 4))) +@pytest.mark.parametrize("apply_ffill", [True, False]) +def test_rolling_overlap(window, apply_ffill): + arr = xr.DataArray( + [ + [1, np.nan, 3], + [np.nan, 4, 6], + [np.nan, 5, np.nan], + [3, np.nan, 7], + [7, 6, np.nan], + ], + dims=["a", "b"], + coords={"a": list(range(5)), "b": list(range(3))}, + ).chunk(a=3, b=1) + df = pd.DataFrame(arr.values.T, arr.b.values, arr.a.values).stack(dropna=False) + for window_margin in range(window, 6): + rolling_arr = Algorithms.rolling_overlap( + arr, + func=nba.move_mean, + window=window, + dim="a", + window_margin=window_margin, + min_periods=1, + apply_ffill=apply_ffill + ) + + expected = df.dropna() + expected = ( + expected.groupby(level=0) + .rolling(window=window, min_periods=1) + .mean() + ) + expected = expected.droplevel(0).unstack(0) + + expected = xr.DataArray(expected.values, coords=arr.coords, dims=arr.dims) + if apply_ffill: + expected = expected.ffill("a") + + if window_margin == 2 and window == 2: + assert ~expected.equals(rolling_arr) + else: + assert expected.equals(rolling_arr) + + if __name__ == "__main__": pass diff --git a/tensordb/tests/test_tools.py b/tensordb/tests/test_tools.py index 59757dd..e8b2e09 100644 --- a/tensordb/tests/test_tools.py +++ b/tensordb/tests/test_tools.py @@ -35,7 +35,7 @@ def test_xarray_from_func_data_array(self): coords={"a": list(range(6)), "b": list(range(8))}, chunks=[2, 3], dtypes=np.float64, - func_parameters={}, + kwargs={}, ) assert data.equals(self.data_array.sel(**data.coords)) @@ -47,7 +47,7 @@ def test_xarray_from_func_dataset(self): chunks=[2, 3], dtypes=[np.float64, np.float64], data_names=["first", "second"], - func_parameters={}, + kwargs={}, ) assert data.equals(self.dataset.sel(data.coords)) diff --git a/tensordb/utils/tools.py b/tensordb/utils/tools.py index 21a1957..80841e3 100644 --- a/tensordb/utils/tools.py +++ b/tensordb/utils/tools.py @@ -93,7 +93,8 @@ def xarray_from_func( chunks: Union[List[Union[int, None]], Dict[Hashable, int]], dtypes: Union[List[Any], Any], data_names: Union[List[Hashable], str] = None, - func_parameters: Dict[str, Any] = None, + args: List[Any] = None, + kwargs: Dict[str, Any] = None, ) -> Union[xr.DataArray, xr.Dataset]: """ Equivalent of dask fromfunction but it sends the xarray coords of every chunk instead of the positions @@ -132,7 +133,10 @@ def xarray_from_func( Indicate the names of the different DataArray inside your Dataset. The data_names must be aligned with dtypes, in other case it will raise an Error. - func_parameters: Dict[str, Any], default None + args: List[Any] = None + Extra parameters for the function + + kwargs: Dict[str, Any], default None Extra parameters for the function """ @@ -141,7 +145,8 @@ def xarray_from_func( chunks = [ len(coords[dim]) if chunk is None else chunk for chunk, dim in zip(chunks, dims) ] - func_parameters = {} if func_parameters is None else func_parameters + kwargs = {} if kwargs is None else kwargs + args = [] if args is None else args if data_names is None or isinstance(data_names, str): arr = empty_xarray(dims, coords, chunks, dtypes) @@ -159,4 +164,4 @@ def xarray_from_func( coords=coords, ) - return arr.map_blocks(func, kwargs=func_parameters, template=arr) + return arr.map_blocks(func, args=args, kwargs=kwargs, template=arr)