Skip to content

Commit ee36b7a

Browse files
authored
TST: Move some window/moments tests to test_ewm (#44128)
1 parent e2fedf1 commit ee36b7a

File tree

3 files changed

+340
-343
lines changed

3 files changed

+340
-343
lines changed

pandas/tests/window/moments/test_moments_consistency_ewm.py

-52
Original file line numberDiff line numberDiff line change
@@ -18,58 +18,6 @@ def test_ewm_pairwise_cov_corr(func, frame):
1818
tm.assert_series_equal(result, expected, check_names=False)
1919

2020

21-
@pytest.mark.parametrize("name", ["cov", "corr"])
22-
def test_ewm_corr_cov(name):
23-
A = Series(np.random.randn(50), index=np.arange(50))
24-
B = A[2:] + np.random.randn(48)
25-
26-
A[:10] = np.NaN
27-
B[-10:] = np.NaN
28-
29-
result = getattr(A.ewm(com=20, min_periods=5), name)(B)
30-
assert np.isnan(result.values[:14]).all()
31-
assert not np.isnan(result.values[14:]).any()
32-
33-
34-
@pytest.mark.parametrize("min_periods", [0, 1, 2])
35-
@pytest.mark.parametrize("name", ["cov", "corr"])
36-
def test_ewm_corr_cov_min_periods(name, min_periods):
37-
# GH 7898
38-
A = Series(np.random.randn(50), index=np.arange(50))
39-
B = A[2:] + np.random.randn(48)
40-
41-
A[:10] = np.NaN
42-
B[-10:] = np.NaN
43-
44-
result = getattr(A.ewm(com=20, min_periods=min_periods), name)(B)
45-
# binary functions (ewmcov, ewmcorr) with bias=False require at
46-
# least two values
47-
assert np.isnan(result.values[:11]).all()
48-
assert not np.isnan(result.values[11:]).any()
49-
50-
# check series of length 0
51-
empty = Series([], dtype=np.float64)
52-
result = getattr(empty.ewm(com=50, min_periods=min_periods), name)(empty)
53-
tm.assert_series_equal(result, empty)
54-
55-
# check series of length 1
56-
result = getattr(Series([1.0]).ewm(com=50, min_periods=min_periods), name)(
57-
Series([1.0])
58-
)
59-
tm.assert_series_equal(result, Series([np.NaN]))
60-
61-
62-
@pytest.mark.parametrize("name", ["cov", "corr"])
63-
def test_different_input_array_raise_exception(name):
64-
A = Series(np.random.randn(50), index=np.arange(50))
65-
A[:10] = np.NaN
66-
67-
msg = "other must be a DataFrame or Series"
68-
# exception raised is Exception
69-
with pytest.raises(ValueError, match=msg):
70-
getattr(A.ewm(com=20, min_periods=5), name)(np.random.randn(50))
71-
72-
7321
def create_mock_weights(obj, com, adjust, ignore_na):
7422
if isinstance(obj, DataFrame):
7523
if not len(obj.columns):
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import pytest
32

43
from pandas import (
@@ -20,187 +19,6 @@ def test_ewma_frame(frame, name):
2019
assert isinstance(frame_result, DataFrame)
2120

2221

23-
def test_ewma_adjust():
24-
vals = Series(np.zeros(1000))
25-
vals[5] = 1
26-
result = vals.ewm(span=100, adjust=False).mean().sum()
27-
assert np.abs(result - 1) < 1e-2
28-
29-
30-
@pytest.mark.parametrize("adjust", [True, False])
31-
@pytest.mark.parametrize("ignore_na", [True, False])
32-
def test_ewma_cases(adjust, ignore_na):
33-
# try adjust/ignore_na args matrix
34-
35-
s = Series([1.0, 2.0, 4.0, 8.0])
36-
37-
if adjust:
38-
expected = Series([1.0, 1.6, 2.736842, 4.923077])
39-
else:
40-
expected = Series([1.0, 1.333333, 2.222222, 4.148148])
41-
42-
result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
43-
tm.assert_series_equal(result, expected)
44-
45-
46-
def test_ewma_nan_handling():
47-
s = Series([1.0] + [np.nan] * 5 + [1.0])
48-
result = s.ewm(com=5).mean()
49-
tm.assert_series_equal(result, Series([1.0] * len(s)))
50-
51-
s = Series([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0])
52-
result = s.ewm(com=5).mean()
53-
tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4))
54-
55-
56-
@pytest.mark.parametrize(
57-
"s, adjust, ignore_na, w",
58-
[
59-
(
60-
Series([np.nan, 1.0, 101.0]),
61-
True,
62-
False,
63-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
64-
),
65-
(
66-
Series([np.nan, 1.0, 101.0]),
67-
True,
68-
True,
69-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
70-
),
71-
(
72-
Series([np.nan, 1.0, 101.0]),
73-
False,
74-
False,
75-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
76-
),
77-
(
78-
Series([np.nan, 1.0, 101.0]),
79-
False,
80-
True,
81-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
82-
),
83-
(
84-
Series([1.0, np.nan, 101.0]),
85-
True,
86-
False,
87-
[(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0],
88-
),
89-
(
90-
Series([1.0, np.nan, 101.0]),
91-
True,
92-
True,
93-
[(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0],
94-
),
95-
(
96-
Series([1.0, np.nan, 101.0]),
97-
False,
98-
False,
99-
[(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))],
100-
),
101-
(
102-
Series([1.0, np.nan, 101.0]),
103-
False,
104-
True,
105-
[(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))],
106-
),
107-
(
108-
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
109-
True,
110-
False,
111-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan],
112-
),
113-
(
114-
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
115-
True,
116-
True,
117-
[np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan],
118-
),
119-
(
120-
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
121-
False,
122-
False,
123-
[
124-
np.nan,
125-
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
126-
np.nan,
127-
np.nan,
128-
(1.0 / (1.0 + 2.0)),
129-
np.nan,
130-
],
131-
),
132-
(
133-
Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
134-
False,
135-
True,
136-
[
137-
np.nan,
138-
(1.0 - (1.0 / (1.0 + 2.0))),
139-
np.nan,
140-
np.nan,
141-
(1.0 / (1.0 + 2.0)),
142-
np.nan,
143-
],
144-
),
145-
(
146-
Series([1.0, np.nan, 101.0, 50.0]),
147-
True,
148-
False,
149-
[
150-
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
151-
np.nan,
152-
(1.0 - (1.0 / (1.0 + 2.0))),
153-
1.0,
154-
],
155-
),
156-
(
157-
Series([1.0, np.nan, 101.0, 50.0]),
158-
True,
159-
True,
160-
[
161-
(1.0 - (1.0 / (1.0 + 2.0))) ** 2,
162-
np.nan,
163-
(1.0 - (1.0 / (1.0 + 2.0))),
164-
1.0,
165-
],
166-
),
167-
(
168-
Series([1.0, np.nan, 101.0, 50.0]),
169-
False,
170-
False,
171-
[
172-
(1.0 - (1.0 / (1.0 + 2.0))) ** 3,
173-
np.nan,
174-
(1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
175-
(1.0 / (1.0 + 2.0))
176-
* ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))),
177-
],
178-
),
179-
(
180-
Series([1.0, np.nan, 101.0, 50.0]),
181-
False,
182-
True,
183-
[
184-
(1.0 - (1.0 / (1.0 + 2.0))) ** 2,
185-
np.nan,
186-
(1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
187-
(1.0 / (1.0 + 2.0)),
188-
],
189-
),
190-
],
191-
)
192-
def test_ewma_nan_handling_cases(s, adjust, ignore_na, w):
193-
# GH 7603
194-
expected = (s.multiply(w).cumsum() / Series(w).cumsum()).fillna(method="ffill")
195-
result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
196-
197-
tm.assert_series_equal(result, expected)
198-
if ignore_na is False:
199-
# check that ignore_na defaults to False
200-
result = s.ewm(com=2.0, adjust=adjust).mean()
201-
tm.assert_series_equal(result, expected)
202-
203-
20422
def test_ewma_span_com_args(series):
20523
A = series.ewm(com=9.5).mean()
20624
B = series.ewm(span=20).mean()
@@ -230,22 +48,6 @@ def test_ewma_halflife_arg(series):
23048
series.ewm()
23149

23250

233-
def test_ewm_alpha():
234-
# GH 10789
235-
arr = np.random.randn(100)
236-
locs = np.arange(20, 40)
237-
arr[locs] = np.NaN
238-
239-
s = Series(arr)
240-
a = s.ewm(alpha=0.61722699889169674).mean()
241-
b = s.ewm(com=0.62014947789973052).mean()
242-
c = s.ewm(span=2.240298955799461).mean()
243-
d = s.ewm(halflife=0.721792864318).mean()
244-
tm.assert_series_equal(a, b)
245-
tm.assert_series_equal(a, c)
246-
tm.assert_series_equal(a, d)
247-
248-
24951
def test_ewm_alpha_arg(series):
25052
# GH 10789
25153
s = series
@@ -260,96 +62,3 @@ def test_ewm_alpha_arg(series):
26062
s.ewm(span=10.0, alpha=0.5)
26163
with pytest.raises(ValueError, match=msg):
26264
s.ewm(halflife=10.0, alpha=0.5)
263-
264-
265-
def test_ewm_domain_checks():
266-
# GH 12492
267-
arr = np.random.randn(100)
268-
locs = np.arange(20, 40)
269-
arr[locs] = np.NaN
270-
271-
s = Series(arr)
272-
msg = "comass must satisfy: comass >= 0"
273-
with pytest.raises(ValueError, match=msg):
274-
s.ewm(com=-0.1)
275-
s.ewm(com=0.0)
276-
s.ewm(com=0.1)
277-
278-
msg = "span must satisfy: span >= 1"
279-
with pytest.raises(ValueError, match=msg):
280-
s.ewm(span=-0.1)
281-
with pytest.raises(ValueError, match=msg):
282-
s.ewm(span=0.0)
283-
with pytest.raises(ValueError, match=msg):
284-
s.ewm(span=0.9)
285-
s.ewm(span=1.0)
286-
s.ewm(span=1.1)
287-
288-
msg = "halflife must satisfy: halflife > 0"
289-
with pytest.raises(ValueError, match=msg):
290-
s.ewm(halflife=-0.1)
291-
with pytest.raises(ValueError, match=msg):
292-
s.ewm(halflife=0.0)
293-
s.ewm(halflife=0.1)
294-
295-
msg = "alpha must satisfy: 0 < alpha <= 1"
296-
with pytest.raises(ValueError, match=msg):
297-
s.ewm(alpha=-0.1)
298-
with pytest.raises(ValueError, match=msg):
299-
s.ewm(alpha=0.0)
300-
s.ewm(alpha=0.1)
301-
s.ewm(alpha=1.0)
302-
with pytest.raises(ValueError, match=msg):
303-
s.ewm(alpha=1.1)
304-
305-
306-
@pytest.mark.parametrize("method", ["mean", "std", "var"])
307-
def test_ew_empty_series(method):
308-
vals = Series([], dtype=np.float64)
309-
310-
ewm = vals.ewm(3)
311-
result = getattr(ewm, method)()
312-
tm.assert_almost_equal(result, vals)
313-
314-
315-
@pytest.mark.parametrize("min_periods", [0, 1])
316-
@pytest.mark.parametrize("name", ["mean", "var", "std"])
317-
def test_ew_min_periods(min_periods, name):
318-
# excluding NaNs correctly
319-
arr = np.random.randn(50)
320-
arr[:10] = np.NaN
321-
arr[-10:] = np.NaN
322-
s = Series(arr)
323-
324-
# check min_periods
325-
# GH 7898
326-
result = getattr(s.ewm(com=50, min_periods=2), name)()
327-
assert result[:11].isna().all()
328-
assert not result[11:].isna().any()
329-
330-
result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
331-
if name == "mean":
332-
assert result[:10].isna().all()
333-
assert not result[10:].isna().any()
334-
else:
335-
# ewm.std, ewm.var (with bias=False) require at least
336-
# two values
337-
assert result[:11].isna().all()
338-
assert not result[11:].isna().any()
339-
340-
# check series of length 0
341-
result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)()
342-
tm.assert_series_equal(result, Series(dtype="float64"))
343-
344-
# check series of length 1
345-
result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
346-
if name == "mean":
347-
tm.assert_series_equal(result, Series([1.0]))
348-
else:
349-
# ewm.std, ewm.var with bias=False require at least
350-
# two values
351-
tm.assert_series_equal(result, Series([np.NaN]))
352-
353-
# pass in ints
354-
result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
355-
assert result2.dtype == np.float_

0 commit comments

Comments
 (0)