diff --git a/pandas/tests/window/common.py b/pandas/tests/window/common.py index c3648bc619c50..e1f322622de48 100644 --- a/pandas/tests/window/common.py +++ b/pandas/tests/window/common.py @@ -348,3 +348,34 @@ def get_result(obj, obj2=None): result.index = result.index.droplevel(1) expected = get_result(self.frame[1], self.frame[5]) tm.assert_series_equal(result, expected, check_names=False) + + +def ew_func(A, B, com, name, **kwargs): + return getattr(A.ewm(com, **kwargs), name)(B) + + +def check_binary_ew(name, A, B): + + result = ew_func(A=A, B=B, com=20, name=name, min_periods=5) + assert np.isnan(result.values[:14]).all() + assert not np.isnan(result.values[14:]).any() + + +def check_binary_ew_min_periods(name, min_periods, A, B): + # GH 7898 + result = ew_func(A, B, 20, name=name, min_periods=min_periods) + # binary functions (ewmcov, ewmcorr) with bias=False require at + # least two values + assert np.isnan(result.values[:11]).all() + assert not np.isnan(result.values[11:]).any() + + # check series of length 0 + empty = Series([], dtype=np.float64) + result = ew_func(empty, empty, 50, name=name, min_periods=min_periods) + tm.assert_series_equal(result, empty) + + # check series of length 1 + result = ew_func( + Series([1.0]), Series([1.0]), 50, name=name, min_periods=min_periods + ) + tm.assert_series_equal(result, Series([np.NaN])) diff --git a/pandas/tests/window/moments/conftest.py b/pandas/tests/window/moments/conftest.py new file mode 100644 index 0000000000000..2002f4d0bff43 --- /dev/null +++ b/pandas/tests/window/moments/conftest.py @@ -0,0 +1,20 @@ +import numpy as np +from numpy.random import randn +import pytest + +from pandas import Series + + +@pytest.fixture +def binary_ew_data(): + A = Series(randn(50), index=np.arange(50)) + B = A[2:] + randn(48) + + A[:10] = np.NaN + B[-10:] = np.NaN + return A, B + + +@pytest.fixture(params=[0, 1, 2]) +def min_periods(request): + return request.param diff --git a/pandas/tests/window/moments/test_moments_ewm.py b/pandas/tests/window/moments/test_moments_ewm.py index bf2bd1420b7f4..46cd503e95e11 100644 --- a/pandas/tests/window/moments/test_moments_ewm.py +++ b/pandas/tests/window/moments/test_moments_ewm.py @@ -4,7 +4,13 @@ import pandas as pd from pandas import DataFrame, Series, concat -from pandas.tests.window.common import Base, ConsistencyBase +from pandas.tests.window.common import ( + Base, + ConsistencyBase, + check_binary_ew, + check_binary_ew_min_periods, + ew_func, +) import pandas.util.testing as tm @@ -216,6 +222,9 @@ def _check_ew(self, name=None, preserve_nan=False): if preserve_nan: assert result[self._nan_locs].isna().all() + @pytest.mark.parametrize("min_periods", [0, 1]) + @pytest.mark.parametrize("name", ["mean", "var", "vol"]) + def test_ew_min_periods(self, min_periods, name): # excluding NaNs correctly arr = randn(50) arr[:10] = np.NaN @@ -228,31 +237,30 @@ def _check_ew(self, name=None, preserve_nan=False): assert result[:11].isna().all() assert not result[11:].isna().any() - for min_periods in (0, 1): - result = getattr(s.ewm(com=50, min_periods=min_periods), name)() - if name == "mean": - assert result[:10].isna().all() - assert not result[10:].isna().any() - else: - # ewm.std, ewm.vol, ewm.var (with bias=False) require at least - # two values - assert result[:11].isna().all() - assert not result[11:].isna().any() - - # check series of length 0 - result = getattr( - Series(dtype=object).ewm(com=50, min_periods=min_periods), name - )() - tm.assert_series_equal(result, Series(dtype="float64")) - - # check series of length 1 - result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() - if name == "mean": - tm.assert_series_equal(result, Series([1.0])) - else: - # ewm.std, ewm.vol, ewm.var with bias=False require at least - # two values - tm.assert_series_equal(result, Series([np.NaN])) + result = getattr(s.ewm(com=50, min_periods=min_periods), name)() + if name == "mean": + assert result[:10].isna().all() + assert not result[10:].isna().any() + else: + # ewm.std, ewm.vol, ewm.var (with bias=False) require at least + # two values + assert result[:11].isna().all() + assert not result[11:].isna().any() + + # check series of length 0 + result = getattr( + Series(dtype=object).ewm(com=50, min_periods=min_periods), name + )() + tm.assert_series_equal(result, Series(dtype="float64")) + + # check series of length 1 + result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() + if name == "mean": + tm.assert_series_equal(result, Series([1.0])) + else: + # ewm.std, ewm.vol, ewm.var with bias=False require at least + # two values + tm.assert_series_equal(result, Series([np.NaN])) # pass in ints result2 = getattr(Series(np.arange(50)).ewm(span=10), name)() @@ -263,53 +271,27 @@ class TestEwmMomentsConsistency(ConsistencyBase): def setup_method(self, method): self._create_data() - def test_ewmcov(self): - self._check_binary_ew("cov") - def test_ewmcov_pairwise(self): self._check_pairwise_moment("ewm", "cov", span=10, min_periods=5) - def test_ewmcorr(self): - self._check_binary_ew("corr") + @pytest.mark.parametrize("name", ["cov", "corr"]) + def test_ewm_corr_cov(self, name, min_periods, binary_ew_data): + A, B = binary_ew_data + + check_binary_ew(name="corr", A=A, B=B) + check_binary_ew_min_periods("corr", min_periods, A, B) def test_ewmcorr_pairwise(self): self._check_pairwise_moment("ewm", "corr", span=10, min_periods=5) - def _check_binary_ew(self, name): - def func(A, B, com, **kwargs): - return getattr(A.ewm(com, **kwargs), name)(B) - - A = Series(randn(50), index=np.arange(50)) - B = A[2:] + randn(48) - - A[:10] = np.NaN - B[-10:] = np.NaN - - result = func(A, B, 20, min_periods=5) - assert np.isnan(result.values[:14]).all() - assert not np.isnan(result.values[14:]).any() - - # GH 7898 - for min_periods in (0, 1, 2): - result = func(A, B, 20, min_periods=min_periods) - # binary functions (ewmcov, ewmcorr) with bias=False require at - # least two values - assert np.isnan(result.values[:11]).all() - assert not np.isnan(result.values[11:]).any() - - # check series of length 0 - empty = Series([], dtype=np.float64) - result = func(empty, empty, 50, min_periods=min_periods) - tm.assert_series_equal(result, empty) - - # check series of length 1 - result = func(Series([1.0]), Series([1.0]), 50, min_periods=min_periods) - tm.assert_series_equal(result, Series([np.NaN])) + @pytest.mark.parametrize("name", ["cov", "corr"]) + def test_different_input_array_raise_exception(self, name, binary_ew_data): + A, _ = binary_ew_data msg = "Input arrays must be of the same type!" # exception raised is Exception with pytest.raises(Exception, match=msg): - func(A, randn(50), 20, min_periods=5) + ew_func(A, randn(50), 20, name=name, min_periods=5) @pytest.mark.slow @pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])