diff --git a/clouddrift/analysis.py b/clouddrift/analysis.py index 16f490c6..4434ac05 100644 --- a/clouddrift/analysis.py +++ b/clouddrift/analysis.py @@ -248,13 +248,13 @@ def prune( """ ragged = apply_ragged( - lambda x, min_len: x if len(x) >= min_len else [], + lambda x, min_len: x if len(x) >= min_len else np.empty(0, dtype=x.dtype), np.array(ragged), rowsize, min_len=min_rowsize, ) rowsize = apply_ragged( - lambda x, min_len: x if x >= min_len else [], + lambda x, min_len: x if x >= min_len else np.empty(0, dtype=x.dtype), np.array(rowsize), np.ones_like(rowsize), min_len=min_rowsize, diff --git a/tests/analysis_tests.py b/tests/analysis_tests.py index a1e47aaa..12d481a3 100644 --- a/tests/analysis_tests.py +++ b/tests/analysis_tests.py @@ -236,6 +236,29 @@ def test_prune_all_smaller(self): np.testing.assert_equal(x_new, np.array([])) np.testing.assert_equal(rowsize_new, np.array([])) + def test_prune_dates(self): + a = pd.date_range( + start=pd.to_datetime("1/1/2018"), + end=pd.to_datetime("1/03/2018"), + ) + + b = pd.date_range( + start=pd.to_datetime("1/1/2018"), + end=pd.to_datetime("1/05/2018"), + ) + + c = pd.date_range( + start=pd.to_datetime("1/1/2018"), + end=pd.to_datetime("1/08/2018"), + ) + + x = np.concatenate((a, b, c)) + rowsize = [len(v) for v in [a, b, c]] + + x_new, rowsize_new = prune(x, rowsize, 5) + np.testing.assert_equal(x_new, np.concatenate((b, c))) + np.testing.assert_equal(rowsize_new, [5, 8]) + def test_prune_keep_nan(self): x = [1, 2, np.nan, 1, 2, 1, 2, np.nan, 4] rowsize = [3, 2, 4]