@@ -863,17 +863,22 @@ def get_op_from_name(self, op_name):
863
863
short_opname = op_name .strip ("_" )
864
864
if short_opname == "rtruediv" :
865
865
# use the numpy version that won't raise on division by zero
866
- return lambda x , y : np .divide (y , x )
866
+
867
+ def rtruediv (x , y ):
868
+ return np .divide (y , x )
869
+
870
+ return rtruediv
867
871
elif short_opname == "rfloordiv" :
868
872
return lambda x , y : np .floor_divide (y , x )
869
873
870
874
return tm .get_op_from_name (op_name )
871
875
872
- def _patch_combine (self , obj , other , op ):
876
+ def _combine (self , obj , other , op ):
873
877
# BaseOpsUtil._combine can upcast expected dtype
874
878
# (because it generates expected on python scalars)
875
879
# while ArrowExtensionArray maintains original type
876
880
expected = base .BaseArithmeticOpsTests ._combine (self , obj , other , op )
881
+
877
882
was_frame = False
878
883
if isinstance (expected , pd .DataFrame ):
879
884
was_frame = True
@@ -883,10 +888,37 @@ def _patch_combine(self, obj, other, op):
883
888
expected_data = expected
884
889
original_dtype = obj .dtype
885
890
891
+ orig_pa_type = original_dtype .pyarrow_dtype
892
+ if not was_frame and isinstance (other , pd .Series ):
893
+ # i.e. test_arith_series_with_array
894
+ if not (
895
+ pa .types .is_floating (orig_pa_type )
896
+ or (
897
+ pa .types .is_integer (orig_pa_type )
898
+ and op .__name__ not in ["truediv" , "rtruediv" ]
899
+ )
900
+ or pa .types .is_duration (orig_pa_type )
901
+ or pa .types .is_timestamp (orig_pa_type )
902
+ or pa .types .is_date (orig_pa_type )
903
+ or pa .types .is_decimal (orig_pa_type )
904
+ ):
905
+ # base class _combine always returns int64, while
906
+ # ArrowExtensionArray does not upcast
907
+ return expected
908
+ elif not (
909
+ (op is operator .floordiv and pa .types .is_integer (orig_pa_type ))
910
+ or pa .types .is_duration (orig_pa_type )
911
+ or pa .types .is_timestamp (orig_pa_type )
912
+ or pa .types .is_date (orig_pa_type )
913
+ or pa .types .is_decimal (orig_pa_type )
914
+ ):
915
+ # base class _combine always returns int64, while
916
+ # ArrowExtensionArray does not upcast
917
+ return expected
918
+
886
919
pa_expected = pa .array (expected_data ._values )
887
920
888
921
if pa .types .is_duration (pa_expected .type ):
889
- orig_pa_type = original_dtype .pyarrow_dtype
890
922
if pa .types .is_date (orig_pa_type ):
891
923
if pa .types .is_date64 (orig_pa_type ):
892
924
# TODO: why is this different vs date32?
@@ -907,7 +939,7 @@ def _patch_combine(self, obj, other, op):
907
939
pa_expected = pa_expected .cast (f"duration[{ unit } ]" )
908
940
909
941
elif pa .types .is_decimal (pa_expected .type ) and pa .types .is_decimal (
910
- original_dtype . pyarrow_dtype
942
+ orig_pa_type
911
943
):
912
944
# decimal precision can resize in the result type depending on data
913
945
# just compare the float values
@@ -929,7 +961,7 @@ def _patch_combine(self, obj, other, op):
929
961
return expected .astype (alt_dtype )
930
962
931
963
else :
932
- pa_expected = pa_expected .cast (original_dtype . pyarrow_dtype )
964
+ pa_expected = pa_expected .cast (orig_pa_type )
933
965
934
966
pd_expected = type (expected_data ._values )(pa_expected )
935
967
if was_frame :
@@ -1043,9 +1075,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
1043
1075
1044
1076
return mark
1045
1077
1046
- def test_arith_series_with_scalar (
1047
- self , data , all_arithmetic_operators , request , monkeypatch
1048
- ):
1078
+ def test_arith_series_with_scalar (self , data , all_arithmetic_operators , request ):
1049
1079
pa_dtype = data .dtype .pyarrow_dtype
1050
1080
1051
1081
if all_arithmetic_operators == "__rmod__" and (
@@ -1061,24 +1091,9 @@ def test_arith_series_with_scalar(
1061
1091
if mark is not None :
1062
1092
request .node .add_marker (mark )
1063
1093
1064
- if (
1065
- (
1066
- all_arithmetic_operators == "__floordiv__"
1067
- and pa .types .is_integer (pa_dtype )
1068
- )
1069
- or pa .types .is_duration (pa_dtype )
1070
- or pa .types .is_timestamp (pa_dtype )
1071
- or pa .types .is_date (pa_dtype )
1072
- or pa .types .is_decimal (pa_dtype )
1073
- ):
1074
- # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1075
- # not upcast
1076
- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
1077
1094
super ().test_arith_series_with_scalar (data , all_arithmetic_operators )
1078
1095
1079
- def test_arith_frame_with_scalar (
1080
- self , data , all_arithmetic_operators , request , monkeypatch
1081
- ):
1096
+ def test_arith_frame_with_scalar (self , data , all_arithmetic_operators , request ):
1082
1097
pa_dtype = data .dtype .pyarrow_dtype
1083
1098
1084
1099
if all_arithmetic_operators == "__rmod__" and (
@@ -1094,24 +1109,9 @@ def test_arith_frame_with_scalar(
1094
1109
if mark is not None :
1095
1110
request .node .add_marker (mark )
1096
1111
1097
- if (
1098
- (
1099
- all_arithmetic_operators == "__floordiv__"
1100
- and pa .types .is_integer (pa_dtype )
1101
- )
1102
- or pa .types .is_duration (pa_dtype )
1103
- or pa .types .is_timestamp (pa_dtype )
1104
- or pa .types .is_date (pa_dtype )
1105
- or pa .types .is_decimal (pa_dtype )
1106
- ):
1107
- # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1108
- # not upcast
1109
- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
1110
1112
super ().test_arith_frame_with_scalar (data , all_arithmetic_operators )
1111
1113
1112
- def test_arith_series_with_array (
1113
- self , data , all_arithmetic_operators , request , monkeypatch
1114
- ):
1114
+ def test_arith_series_with_array (self , data , all_arithmetic_operators , request ):
1115
1115
pa_dtype = data .dtype .pyarrow_dtype
1116
1116
1117
1117
self .series_array_exc = self ._get_scalar_exception (
@@ -1147,18 +1147,6 @@ def test_arith_series_with_array(
1147
1147
# since ser.iloc[0] is a python scalar
1148
1148
other = pd .Series (pd .array ([ser .iloc [0 ]] * len (ser ), dtype = data .dtype ))
1149
1149
1150
- if (
1151
- pa .types .is_floating (pa_dtype )
1152
- or (
1153
- pa .types .is_integer (pa_dtype )
1154
- and all_arithmetic_operators not in ["__truediv__" , "__rtruediv__" ]
1155
- )
1156
- or pa .types .is_duration (pa_dtype )
1157
- or pa .types .is_timestamp (pa_dtype )
1158
- or pa .types .is_date (pa_dtype )
1159
- or pa .types .is_decimal (pa_dtype )
1160
- ):
1161
- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
1162
1150
self .check_opname (ser , op_name , other , exc = self .series_array_exc )
1163
1151
1164
1152
def test_add_series_with_extension_array (self , data , request ):
0 commit comments