diff --git a/sgkit/stats/cohort_numba_fns.py b/sgkit/stats/cohort_numba_fns.py index 796f1190..a19f4180 100644 --- a/sgkit/stats/cohort_numba_fns.py +++ b/sgkit/stats/cohort_numba_fns.py @@ -1,9 +1,9 @@ from functools import wraps from typing import Callable -import dask.array as da import numpy as np +import sgkit.distarray as da from sgkit.accelerate import numba_guvectorize from ..typing import ArrayLike diff --git a/sgkit/tests/test_cohort_numba_fns.py b/sgkit/tests/test_cohort_numba_fns.py index 2239e3e5..3a131258 100644 --- a/sgkit/tests/test_cohort_numba_fns.py +++ b/sgkit/tests/test_cohort_numba_fns.py @@ -1,7 +1,7 @@ -import dask.array as da import numpy as np import pytest +import sgkit.distarray as da from sgkit.stats.cohort_numba_fns import ( cohort_mean, cohort_nanmean, @@ -41,7 +41,7 @@ def _cohort_reduction(func, x, cohort, n, axis=-1): _random_cohort_data((20, 20), n=3, axis=-1, missing=0.3), _random_cohort_data((7, 103, 4), n=5, axis=1, scale=7, missing=0.3), _random_cohort_data( - ((3, 4), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8 + ((4, 3), (50, 50, 3), 4), n=5, axis=1, scale=7, dtype=np.uint8 ), _random_cohort_data( ((6, 6), (50, 50, 7), (3, 1)), n=5, axis=1, scale=7, missing=0.3