Skip to content

Commit 7f6ea67

Browse files
Dr-Irvjreback
authored andcommitted
BUG: Series.combine() fails with ExtensionArray inside of Series (#21183)
1 parent 4f1704e commit 7f6ea67

File tree

9 files changed

+135
-6
lines changed

9 files changed

+135
-6
lines changed

doc/source/whatsnew/v0.24.0.txt

+9
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,18 @@ Reshaping
179179
-
180180
-
181181

182+
ExtensionArray
183+
^^^^^^^^^^^^^^
184+
185+
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
186+
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
187+
-
188+
-
189+
182190
Other
183191
^^^^^
184192

185193
- :meth: `~pandas.io.formats.style.Styler.background_gradient` now takes a ``text_color_threshold`` parameter to automatically lighten the text color based on the luminance of the background color. This improves readability with dark background colors without the need to limit the background colormap range. (:issue:`21258`)
186194
-
187195
-
196+
-

pandas/core/series.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -2204,7 +2204,7 @@ def _binop(self, other, func, level=None, fill_value=None):
22042204
result.name = None
22052205
return result
22062206

2207-
def combine(self, other, func, fill_value=np.nan):
2207+
def combine(self, other, func, fill_value=None):
22082208
"""
22092209
Perform elementwise binary operation on two Series using given function
22102210
with optional fill value when an index is missing from one Series or
@@ -2216,6 +2216,8 @@ def combine(self, other, func, fill_value=np.nan):
22162216
func : function
22172217
Function that takes two scalars as inputs and return a scalar
22182218
fill_value : scalar value
2219+
The default specifies to use the appropriate NaN value for
2220+
the underlying dtype of the Series
22192221
22202222
Returns
22212223
-------
@@ -2235,20 +2237,38 @@ def combine(self, other, func, fill_value=np.nan):
22352237
Series.combine_first : Combine Series values, choosing the calling
22362238
Series's values first
22372239
"""
2240+
if fill_value is None:
2241+
fill_value = na_value_for_dtype(self.dtype, compat=False)
2242+
22382243
if isinstance(other, Series):
2244+
# If other is a Series, result is based on union of Series,
2245+
# so do this element by element
22392246
new_index = self.index.union(other.index)
22402247
new_name = ops.get_op_result_name(self, other)
2241-
new_values = np.empty(len(new_index), dtype=self.dtype)
2242-
for i, idx in enumerate(new_index):
2248+
new_values = []
2249+
for idx in new_index:
22432250
lv = self.get(idx, fill_value)
22442251
rv = other.get(idx, fill_value)
22452252
with np.errstate(all='ignore'):
2246-
new_values[i] = func(lv, rv)
2253+
new_values.append(func(lv, rv))
22472254
else:
2255+
# Assume that other is a scalar, so apply the function for
2256+
# each element in the Series
22482257
new_index = self.index
22492258
with np.errstate(all='ignore'):
2250-
new_values = func(self._values, other)
2259+
new_values = [func(lv, other) for lv in self._values]
22512260
new_name = self.name
2261+
2262+
if is_categorical_dtype(self.values):
2263+
pass
2264+
elif is_extension_array_dtype(self.values):
2265+
# The function can return something of any type, so check
2266+
# if the type is compatible with the calling EA
2267+
try:
2268+
new_values = self._values._from_sequence(new_values)
2269+
except TypeError:
2270+
pass
2271+
22522272
return self._constructor(new_values, index=new_index, name=new_name)
22532273

22542274
def combine_first(self, other):

pandas/tests/extension/base/methods.py

+34
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,37 @@ def test_factorize_equivalence(self, data_for_grouping, na_sentinel):
103103

104104
tm.assert_numpy_array_equal(l1, l2)
105105
self.assert_extension_array_equal(u1, u2)
106+
107+
def test_combine_le(self, data_repeated):
108+
# GH 20825
109+
# Test that combine works when doing a <= (le) comparison
110+
orig_data1, orig_data2 = data_repeated(2)
111+
s1 = pd.Series(orig_data1)
112+
s2 = pd.Series(orig_data2)
113+
result = s1.combine(s2, lambda x1, x2: x1 <= x2)
114+
expected = pd.Series([a <= b for (a, b) in
115+
zip(list(orig_data1), list(orig_data2))])
116+
self.assert_series_equal(result, expected)
117+
118+
val = s1.iloc[0]
119+
result = s1.combine(val, lambda x1, x2: x1 <= x2)
120+
expected = pd.Series([a <= val for a in list(orig_data1)])
121+
self.assert_series_equal(result, expected)
122+
123+
def test_combine_add(self, data_repeated):
124+
# GH 20825
125+
orig_data1, orig_data2 = data_repeated(2)
126+
s1 = pd.Series(orig_data1)
127+
s2 = pd.Series(orig_data2)
128+
result = s1.combine(s2, lambda x1, x2: x1 + x2)
129+
expected = pd.Series(
130+
orig_data1._from_sequence([a + b for (a, b) in
131+
zip(list(orig_data1),
132+
list(orig_data2))]))
133+
self.assert_series_equal(result, expected)
134+
135+
val = s1.iloc[0]
136+
result = s1.combine(val, lambda x1, x2: x1 + x2)
137+
expected = pd.Series(
138+
orig_data1._from_sequence([a + val for a in list(orig_data1)]))
139+
self.assert_series_equal(result, expected)

pandas/tests/extension/category/test_categorical.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import string
22

33
import pytest
4+
import pandas as pd
45
import numpy as np
56

67
from pandas.api.types import CategoricalDtype
@@ -29,6 +30,15 @@ def data_missing():
2930
return Categorical([np.nan, 'A'])
3031

3132

33+
@pytest.fixture
34+
def data_repeated():
35+
"""Return different versions of data for count times"""
36+
def gen(count):
37+
for _ in range(count):
38+
yield Categorical(make_data())
39+
yield gen
40+
41+
3242
@pytest.fixture
3343
def data_for_sorting():
3444
return Categorical(['A', 'B', 'C'], categories=['C', 'A', 'B'],
@@ -154,6 +164,22 @@ class TestMethods(base.BaseMethodsTests):
154164
def test_value_counts(self, all_data, dropna):
155165
pass
156166

167+
def test_combine_add(self, data_repeated):
168+
# GH 20825
169+
# When adding categoricals in combine, result is a string
170+
orig_data1, orig_data2 = data_repeated(2)
171+
s1 = pd.Series(orig_data1)
172+
s2 = pd.Series(orig_data2)
173+
result = s1.combine(s2, lambda x1, x2: x1 + x2)
174+
expected = pd.Series(([a + b for (a, b) in
175+
zip(list(orig_data1), list(orig_data2))]))
176+
self.assert_series_equal(result, expected)
177+
178+
val = s1.iloc[0]
179+
result = s1.combine(val, lambda x1, x2: x1 + x2)
180+
expected = pd.Series([a + val for a in list(orig_data1)])
181+
self.assert_series_equal(result, expected)
182+
157183

158184
class TestCasting(base.BaseCastingTests):
159185
pass

pandas/tests/extension/conftest.py

+9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ def all_data(request, data, data_missing):
3030
return data_missing
3131

3232

33+
@pytest.fixture
34+
def data_repeated():
35+
"""Return different versions of data for count times"""
36+
def gen(count):
37+
for _ in range(count):
38+
yield NotImplementedError
39+
yield gen
40+
41+
3342
@pytest.fixture
3443
def data_for_sorting():
3544
"""Length-3 array with a known sort order.

pandas/tests/extension/decimal/array.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class DecimalArray(ExtensionArray):
2828
dtype = DecimalDtype()
2929

3030
def __init__(self, values):
31-
assert all(isinstance(v, decimal.Decimal) for v in values)
31+
for val in values:
32+
if not isinstance(val, self.dtype.type):
33+
raise TypeError
3234
values = np.asarray(values, dtype=object)
3335

3436
self._data = values

pandas/tests/extension/decimal/test_decimal.py

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ def data_missing():
2525
return DecimalArray([decimal.Decimal('NaN'), decimal.Decimal(1)])
2626

2727

28+
@pytest.fixture
29+
def data_repeated():
30+
def gen(count):
31+
for _ in range(count):
32+
yield DecimalArray(make_data())
33+
yield gen
34+
35+
2836
@pytest.fixture
2937
def data_for_sorting():
3038
return DecimalArray([decimal.Decimal('1'),

pandas/tests/extension/json/test_json.py

+8
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,14 @@ def test_sort_values_missing(self, data_missing_for_sorting, ascending):
187187
super(TestMethods, self).test_sort_values_missing(
188188
data_missing_for_sorting, ascending)
189189

190+
@pytest.mark.skip(reason="combine for JSONArray not supported")
191+
def test_combine_le(self, data_repeated):
192+
pass
193+
194+
@pytest.mark.skip(reason="combine for JSONArray not supported")
195+
def test_combine_add(self, data_repeated):
196+
pass
197+
190198

191199
class TestCasting(BaseJSON, base.BaseCastingTests):
192200
@pytest.mark.xfail

pandas/tests/series/test_combine_concat.py

+13
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@ def test_append_duplicates(self):
6060
with tm.assert_raises_regex(ValueError, msg):
6161
pd.concat([s1, s2], verify_integrity=True)
6262

63+
def test_combine_scalar(self):
64+
# GH 21248
65+
# Note - combine() with another Series is tested elsewhere because
66+
# it is used when testing operators
67+
s = pd.Series([i * 10 for i in range(5)])
68+
result = s.combine(3, lambda x, y: x + y)
69+
expected = pd.Series([i * 10 + 3 for i in range(5)])
70+
tm.assert_series_equal(result, expected)
71+
72+
result = s.combine(22, lambda x, y: min(x, y))
73+
expected = pd.Series([min(i * 10, 22) for i in range(5)])
74+
tm.assert_series_equal(result, expected)
75+
6376
def test_combine_first(self):
6477
values = tm.makeIntIndex(20).values.astype(float)
6578
series = Series(values, index=tm.makeIntIndex(20))

0 commit comments

Comments
 (0)