Skip to content

Commit

Permalink
Merge pull request #118 from josephnowak/feature/add-new-rolling-over…
Browse files Browse the repository at this point in the history
…lap-and-standarize-xarray-from-func

Feature/add new rolling overlap and standardize xarray from func
  • Loading branch information
josephnowak authored Apr 10, 2024
2 parents 4d7964f + eae0e20 commit db9750f
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 8 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ zarr>=2.0.0
orjson>=3.0.0
pydantic>=2.0.0
more-itertools>=10.0.0
loguru>=0.7.0
loguru>=0.7.0
numbagg>=0.8.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='TensorDB',
version='0.30.1',
version='0.30.3',
description='Database based in a file system storage combined with Xarray and Zarr',
author='Joseph Nowak',
author_email='josephgonowak97@gmail.com',
Expand Down
66 changes: 66 additions & 0 deletions tensordb/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import dask.array as da
import numpy as np
import xarray as xr
import numbagg as nba
from dask.distributed import Client
from scipy.stats import rankdata

Expand Down Expand Up @@ -731,3 +732,68 @@ def structured_inequality(x, top):
)

return bitmask.expand_dims({tie_breaker_dim: new_data.coords[tie_breaker_dim]})

@classmethod
def rolling_overlap(
    cls,
    new_data: Union[xr.DataArray, xr.Dataset],
    func: Callable,
    dim: str,
    window: int,
    window_margin: int,
    min_periods: Union[int, None] = None,
    apply_ffill: bool = True,
    validate_window_size: bool = True,
):
    """
    Apply a moving-window function along ``dim`` using only the non-NaN values
    of every 1-D slice, sharing ``window_margin`` elements between chunks.

    For every 1-D slice along ``dim`` the NaN entries are removed, ``func`` is
    evaluated on the remaining values and its output is written back to the
    positions that originally held valid data. The per-chunk work is
    dispatched with ``map_overlap`` on ``new_data.data``.

    Parameters
    ----------
    new_data: Union[xr.DataArray, xr.Dataset]
        Data to process. For a Dataset the method is applied to every
        data variable through ``Dataset.map``.
        NOTE(review): ``map_overlap`` is called on ``new_data.data``, so the
        array is presumably expected to be chunked (dask-backed) - confirm.

    func: Callable
        Moving-window function. It is called as
        ``func(values, window=..., min_count=..., out=values)`` and its return
        value is ignored, so it must write the result into ``out``
        (the tests use ``numbagg.move_mean``).

    dim: str
        Dimension along which the window moves.

    window: int
        Size of the moving window, must be greater than 0.

    window_margin: int
        Depth of the overlap between chunks along ``dim``, it must be greater
        or equal than ``window``. Because the NaN values are dropped before
        calling ``func``, a window can reach further back than ``window``
        positions, so a margin that is too small can change the result near
        the chunk borders.

    min_periods: int, default None
        Forwarded to ``func`` as ``min_count``. If None, ``window`` is used.

    apply_ffill: bool, default True
        If True, every slice is forward filled (``numbagg.ffill``)
        after the window results are written back.

    validate_window_size: bool, default True
        If True, ``window_margin`` is limited to the size of ``dim`` and the
        window sent to ``func`` is limited to the number of valid values of
        the slice (slices without valid values are returned unmodified).

    Returns
    -------
    An xr.DataArray with the dims and coords of ``new_data``, or the result of
    ``Dataset.map`` when ``new_data`` is a Dataset.
    """
    assert window > 0
    assert window_margin >= window

    if isinstance(new_data, xr.Dataset):
        # ``func`` is sent through ``args`` because ``Dataset.map`` already has
        # a parameter called ``func``; sending it as a keyword raises a
        # TypeError. The original (unclipped) window_margin is forwarded so
        # every variable clips it by itself without breaking the
        # ``window_margin >= window`` validation of the nested call.
        return new_data.map(
            cls.rolling_overlap,
            args=(func,),
            dim=dim,
            window=window,
            window_margin=window_margin,
            min_periods=min_periods,
            apply_ffill=apply_ffill,
            validate_window_size=validate_window_size,
        )

    if validate_window_size:
        window_margin = min(window_margin, new_data.sizes[dim])

    min_periods = window if min_periods is None else min_periods
    axis = new_data.dims.index(dim)
    data = new_data.data

    def _apply_on_valid(x):
        # apply_along_axis hands over a slice of the block, work on a copy
        # to avoid modifying the input in place.
        x = x.copy()
        bitmask = ~np.isnan(x)
        filter_x = x[bitmask]

        min_window = window
        if validate_window_size:
            min_window = min(window, len(filter_x))
            if min_window == 0:
                # There is no valid data on the slice, nothing to compute.
                return x

        # The result is written in place on filter_x through ``out``.
        func(filter_x, window=min_window, min_count=min_periods, out=filter_x)

        x[bitmask] = filter_x
        if apply_ffill:
            nba.ffill(arr=x, limit=len(x), out=x)
        return x

    def apply_on_valid(x):
        return np.apply_along_axis(func1d=_apply_on_valid, axis=axis, arr=x)

    data = data.map_overlap(
        apply_on_valid,
        depth={axis: window_margin},
        boundary=None,
        meta=data,
        trim=True,
    )

    return xr.DataArray(data, dims=new_data.dims, coords=new_data.coords)
45 changes: 45 additions & 0 deletions tensordb/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import pytest
import xarray as xr
import numbagg as nba

from tensordb.algorithms import Algorithms

Expand Down Expand Up @@ -445,5 +446,49 @@ def test_cumulative_on_sort(dim, ascending, func):
assert result.equals(expected)


@pytest.mark.parametrize("window", list(range(1, 4)))
@pytest.mark.parametrize("apply_ffill", [True, False])
def test_rolling_overlap(window, apply_ffill):
    """
    Compare Algorithms.rolling_overlap (with numbagg.move_mean) against a
    pandas rolling mean computed over the non-NaN values of every column.

    The array is split in chunks of 3 elements along "a", so the margins that
    are evaluated cover the cases where the overlap is smaller and bigger than
    the chunk size.
    """
    arr = xr.DataArray(
        [
            [1, np.nan, 3],
            [np.nan, 4, 6],
            [np.nan, 5, np.nan],
            [3, np.nan, 7],
            [7, 6, np.nan],
        ],
        dims=["a", "b"],
        coords={"a": list(range(5)), "b": list(range(3))},
    ).chunk(a=3, b=1)
    df = pd.DataFrame(arr.values.T, arr.b.values, arr.a.values).stack(dropna=False)

    # The reference result does not depend on the margin, so it is computed once.
    expected = (
        df.dropna()
        .groupby(level=0)
        .rolling(window=window, min_periods=1)
        .mean()
    )
    expected = expected.droplevel(0).unstack(0)
    expected = xr.DataArray(expected.values, coords=arr.coords, dims=arr.dims)
    if apply_ffill:
        expected = expected.ffill("a")

    for window_margin in range(window, 6):
        rolling_arr = Algorithms.rolling_overlap(
            arr,
            func=nba.move_mean,
            window=window,
            dim="a",
            window_margin=window_margin,
            min_periods=1,
            apply_ffill=apply_ffill,
        )

        if window_margin == 2 and window == 2:
            # With a margin of 2 the second chunk can not see the valid value
            # stored at a=0 for b=0, so the result must differ from pandas.
            # ``not`` is required here: ``~`` over a Python bool returns
            # -2 or -1, both truthy, which makes the assertion always pass.
            assert not expected.equals(rolling_arr), (window, window_margin)
        else:
            assert expected.equals(rolling_arr), (window, window_margin)


# Running this module directly as a script does nothing.
if __name__ == "__main__":
    pass
4 changes: 2 additions & 2 deletions tensordb/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_xarray_from_func_data_array(self):
coords={"a": list(range(6)), "b": list(range(8))},
chunks=[2, 3],
dtypes=np.float64,
func_parameters={},
kwargs={},
)
assert data.equals(self.data_array.sel(**data.coords))

Expand All @@ -47,7 +47,7 @@ def test_xarray_from_func_dataset(self):
chunks=[2, 3],
dtypes=[np.float64, np.float64],
data_names=["first", "second"],
func_parameters={},
kwargs={},
)
assert data.equals(self.dataset.sel(data.coords))

Expand Down
13 changes: 9 additions & 4 deletions tensordb/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def xarray_from_func(
chunks: Union[List[Union[int, None]], Dict[Hashable, int]],
dtypes: Union[List[Any], Any],
data_names: Union[List[Hashable], str] = None,
func_parameters: Dict[str, Any] = None,
args: List[Any] = None,
kwargs: Dict[str, Any] = None,
) -> Union[xr.DataArray, xr.Dataset]:
"""
Equivalent of dask fromfunction but it sends the xarray coords of every chunk instead of the positions
Expand Down Expand Up @@ -132,7 +133,10 @@ def xarray_from_func(
Indicate the names of the different DataArray inside your Dataset.
The data_names must be aligned with dtypes, in other case it will raise an Error.
func_parameters: Dict[str, Any], default None
args: List[Any] = None
Extra parameters for the function
kwargs: Dict[str, Any], default None
Extra parameters for the function
"""
Expand All @@ -141,7 +145,8 @@ def xarray_from_func(
chunks = [
len(coords[dim]) if chunk is None else chunk for chunk, dim in zip(chunks, dims)
]
func_parameters = {} if func_parameters is None else func_parameters
kwargs = {} if kwargs is None else kwargs
args = [] if args is None else args

if data_names is None or isinstance(data_names, str):
arr = empty_xarray(dims, coords, chunks, dtypes)
Expand All @@ -159,4 +164,4 @@ def xarray_from_func(
coords=coords,
)

return arr.map_blocks(func, kwargs=func_parameters, template=arr)
return arr.map_blocks(func, args=args, kwargs=kwargs, template=arr)

0 comments on commit db9750f

Please sign in to comment.