Skip to content

Commit

Permalink
API/TST: expand tests for string any/all reduction + fix pyarrow-base…
Browse files Browse the repository at this point in the history
…d implementation (pandas-dev#59414)
  • Loading branch information
jorisvandenbossche committed Oct 7, 2024
1 parent 2465a6d commit 35ebe68
Showing 1 changed file with 44 additions and 7 deletions.
51 changes: 44 additions & 7 deletions pandas/tests/reductions/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,25 +1091,62 @@ def test_any_all_datetimelike(self):
assert df.any().all()
assert not df.all().any()

def test_any_all_pyarrow_string(self):
def test_any_all_string_dtype(self, any_string_dtype):
# GH#54591
pytest.importorskip("pyarrow")
ser = Series(["", "a"], dtype="string[pyarrow_numpy]")
if (
isinstance(any_string_dtype, pd.StringDtype)
and any_string_dtype.na_value is pd.NA
):
# the nullable string dtype currently still raise an error
# https://github.com/pandas-dev/pandas/issues/51939
ser = Series(["a", "b"], dtype=any_string_dtype)
with pytest.raises(TypeError):
ser.any()
with pytest.raises(TypeError):
ser.all()
return

ser = Series(["", "a"], dtype=any_string_dtype)
assert ser.any()
assert not ser.all()
assert ser.any(skipna=False)
assert not ser.all(skipna=False)

ser = Series([None, "a"], dtype="string[pyarrow_numpy]")
ser = Series([np.nan, "a"], dtype=any_string_dtype)
assert ser.any()
assert ser.all()
assert not ser.all(skipna=False)
assert ser.any(skipna=False)
assert ser.all(skipna=False) # NaN is considered truthy

ser = Series([None, ""], dtype="string[pyarrow_numpy]")
ser = Series([np.nan, ""], dtype=any_string_dtype)
assert not ser.any()
assert not ser.all()
assert ser.any(skipna=False) # NaN is considered truthy
assert not ser.all(skipna=False)

ser = Series(["a", "b"], dtype="string[pyarrow_numpy]")
ser = Series(["a", "b"], dtype=any_string_dtype)
assert ser.any()
assert ser.all()
assert ser.any(skipna=False)
assert ser.all(skipna=False)

ser = Series([], dtype=any_string_dtype)
assert not ser.any()
assert ser.all()
assert not ser.any(skipna=False)
assert ser.all(skipna=False)

ser = Series([""], dtype=any_string_dtype)
assert not ser.any()
assert not ser.all()
assert not ser.any(skipna=False)
assert not ser.all(skipna=False)

ser = Series([np.nan], dtype=any_string_dtype)
assert not ser.any()
assert ser.all()
assert ser.any(skipna=False) # NaN is considered truthy
assert ser.all(skipna=False) # NaN is considered truthy

def test_timedelta64_analytics(self):
# index min/max
Expand Down

0 comments on commit 35ebe68

Please sign in to comment.