diff --git a/src/earthkit/meteo/stats/array/quantiles.py b/src/earthkit/meteo/stats/array/quantiles.py index 33e1cd6..0f9109a 100644 --- a/src/earthkit/meteo/stats/array/quantiles.py +++ b/src/earthkit/meteo/stats/array/quantiles.py @@ -59,6 +59,7 @@ def iter_quantiles( if method == "sort": arr = np.asarray(arr) arr.sort(axis=axis) + missing = np.isnan(arr).any(axis=axis) for q in qs: if method == "numpy": @@ -74,4 +75,5 @@ def iter_quantiles( tmp = arr.take(min(j + 1, m - 1), axis=axis) tmp *= x quantile += tmp + quantile[missing] = np.nan yield quantile diff --git a/tests/stats/test_stats.py b/tests/stats/test_stats.py index 7709313..ba7aea4 100644 --- a/tests/stats/test_stats.py +++ b/tests/stats/test_stats.py @@ -65,3 +65,16 @@ def test_quantiles(method): [5, 19, 12, 3, 45, 48, 8, 9, 7], ] ) + + +def test_quantiles_nans(): + arr = np.random.rand(100, 100, 100) + arr.ravel()[np.random.choice(arr.size, 100000, replace=False)] = np.nan + qs = [0.0, 0.25, 0.5, 0.75, 1.0] + sort = [ + quantile for quantile in stats.iter_quantiles(arr.copy(), qs, method="sort") + ] + numpy = [ + quantile for quantile in stats.iter_quantiles(arr.copy(), qs, method="numpy") + ] + assert np.all(np.isclose(sort, numpy, equal_nan=True))