Skip to content

Commit 1553ec3

Browse files
authored
REF: avoid monkeypatch in arrow tests (#54361)
* REF: avoid monkeypatch in arrow tests * cleanup
1 parent 9db9baa commit 1553ec3

File tree

1 file changed

+40
-52
lines changed

1 file changed

+40
-52
lines changed

pandas/tests/extension/test_arrow.py

Lines changed: 40 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -863,17 +863,22 @@ def get_op_from_name(self, op_name):
863863
short_opname = op_name.strip("_")
864864
if short_opname == "rtruediv":
865865
# use the numpy version that won't raise on division by zero
866-
return lambda x, y: np.divide(y, x)
866+
867+
def rtruediv(x, y):
868+
return np.divide(y, x)
869+
870+
return rtruediv
867871
elif short_opname == "rfloordiv":
868872
return lambda x, y: np.floor_divide(y, x)
869873

870874
return tm.get_op_from_name(op_name)
871875

872-
def _patch_combine(self, obj, other, op):
876+
def _combine(self, obj, other, op):
873877
# BaseOpsUtil._combine can upcast expected dtype
874878
# (because it generates expected on python scalars)
875879
# while ArrowExtensionArray maintains original type
876880
expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)
881+
877882
was_frame = False
878883
if isinstance(expected, pd.DataFrame):
879884
was_frame = True
@@ -883,10 +888,37 @@ def _patch_combine(self, obj, other, op):
883888
expected_data = expected
884889
original_dtype = obj.dtype
885890

891+
orig_pa_type = original_dtype.pyarrow_dtype
892+
if not was_frame and isinstance(other, pd.Series):
893+
# i.e. test_arith_series_with_array
894+
if not (
895+
pa.types.is_floating(orig_pa_type)
896+
or (
897+
pa.types.is_integer(orig_pa_type)
898+
and op.__name__ not in ["truediv", "rtruediv"]
899+
)
900+
or pa.types.is_duration(orig_pa_type)
901+
or pa.types.is_timestamp(orig_pa_type)
902+
or pa.types.is_date(orig_pa_type)
903+
or pa.types.is_decimal(orig_pa_type)
904+
):
905+
# base class _combine always returns int64, while
906+
# ArrowExtensionArray does not upcast
907+
return expected
908+
elif not (
909+
(op is operator.floordiv and pa.types.is_integer(orig_pa_type))
910+
or pa.types.is_duration(orig_pa_type)
911+
or pa.types.is_timestamp(orig_pa_type)
912+
or pa.types.is_date(orig_pa_type)
913+
or pa.types.is_decimal(orig_pa_type)
914+
):
915+
# base class _combine always returns int64, while
916+
# ArrowExtensionArray does not upcast
917+
return expected
918+
886919
pa_expected = pa.array(expected_data._values)
887920

888921
if pa.types.is_duration(pa_expected.type):
889-
orig_pa_type = original_dtype.pyarrow_dtype
890922
if pa.types.is_date(orig_pa_type):
891923
if pa.types.is_date64(orig_pa_type):
892924
# TODO: why is this different vs date32?
@@ -907,7 +939,7 @@ def _patch_combine(self, obj, other, op):
907939
pa_expected = pa_expected.cast(f"duration[{unit}]")
908940

909941
elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(
910-
original_dtype.pyarrow_dtype
942+
orig_pa_type
911943
):
912944
# decimal precision can resize in the result type depending on data
913945
# just compare the float values
@@ -929,7 +961,7 @@ def _patch_combine(self, obj, other, op):
929961
return expected.astype(alt_dtype)
930962

931963
else:
932-
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
964+
pa_expected = pa_expected.cast(orig_pa_type)
933965

934966
pd_expected = type(expected_data._values)(pa_expected)
935967
if was_frame:
@@ -1043,9 +1075,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
10431075

10441076
return mark
10451077

1046-
def test_arith_series_with_scalar(
1047-
self, data, all_arithmetic_operators, request, monkeypatch
1048-
):
1078+
def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request):
10491079
pa_dtype = data.dtype.pyarrow_dtype
10501080

10511081
if all_arithmetic_operators == "__rmod__" and (
@@ -1061,24 +1091,9 @@ def test_arith_series_with_scalar(
10611091
if mark is not None:
10621092
request.node.add_marker(mark)
10631093

1064-
if (
1065-
(
1066-
all_arithmetic_operators == "__floordiv__"
1067-
and pa.types.is_integer(pa_dtype)
1068-
)
1069-
or pa.types.is_duration(pa_dtype)
1070-
or pa.types.is_timestamp(pa_dtype)
1071-
or pa.types.is_date(pa_dtype)
1072-
or pa.types.is_decimal(pa_dtype)
1073-
):
1074-
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1075-
# not upcast
1076-
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
10771094
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
10781095

1079-
def test_arith_frame_with_scalar(
1080-
self, data, all_arithmetic_operators, request, monkeypatch
1081-
):
1096+
def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
10821097
pa_dtype = data.dtype.pyarrow_dtype
10831098

10841099
if all_arithmetic_operators == "__rmod__" and (
@@ -1094,24 +1109,9 @@ def test_arith_frame_with_scalar(
10941109
if mark is not None:
10951110
request.node.add_marker(mark)
10961111

1097-
if (
1098-
(
1099-
all_arithmetic_operators == "__floordiv__"
1100-
and pa.types.is_integer(pa_dtype)
1101-
)
1102-
or pa.types.is_duration(pa_dtype)
1103-
or pa.types.is_timestamp(pa_dtype)
1104-
or pa.types.is_date(pa_dtype)
1105-
or pa.types.is_decimal(pa_dtype)
1106-
):
1107-
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1108-
# not upcast
1109-
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
11101112
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)
11111113

1112-
def test_arith_series_with_array(
1113-
self, data, all_arithmetic_operators, request, monkeypatch
1114-
):
1114+
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
11151115
pa_dtype = data.dtype.pyarrow_dtype
11161116

11171117
self.series_array_exc = self._get_scalar_exception(
@@ -1147,18 +1147,6 @@ def test_arith_series_with_array(
11471147
# since ser.iloc[0] is a python scalar
11481148
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
11491149

1150-
if (
1151-
pa.types.is_floating(pa_dtype)
1152-
or (
1153-
pa.types.is_integer(pa_dtype)
1154-
and all_arithmetic_operators not in ["__truediv__", "__rtruediv__"]
1155-
)
1156-
or pa.types.is_duration(pa_dtype)
1157-
or pa.types.is_timestamp(pa_dtype)
1158-
or pa.types.is_date(pa_dtype)
1159-
or pa.types.is_decimal(pa_dtype)
1160-
):
1161-
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
11621150
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
11631151

11641152
def test_add_series_with_extension_array(self, data, request):

0 commit comments

Comments
 (0)