Skip to content

Commit bfd99d3

Browse files
committed
coerce_extension
1 parent 7940296 commit bfd99d3

File tree

4 files changed

+49
-12
lines changed

4 files changed

+49
-12
lines changed

pandas/core/frame.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
_possibly_downcast_to_dtype,
3131
_invalidate_string_dtypes,
3232
_coerce_to_dtypes,
33+
_coerce_extension_to_embed,
3334
_maybe_upcast_putmask,
3435
_find_common_type)
3536
from pandas.types.common import (is_categorical_dtype,
@@ -2647,7 +2648,7 @@ def reindexer(value):
26472648

26482649
# return internal types directly
26492650
if is_extension_type(value):
2650-
return value
2651+
return _coerce_extension_to_embed(value)
26512652

26522653
# broadcast across multiple columns if necessary
26532654
if broadcast and key in self.columns and value.ndim == 1:

pandas/tests/frame/test_alter_axes.py

+27-5
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88

99
from pandas.compat import lrange
1010
from pandas import (DataFrame, Series, Index, MultiIndex,
11-
RangeIndex, date_range)
11+
RangeIndex, date_range, IntervalIndex)
1212
import pandas as pd
1313

1414
from pandas.util.testing import (assert_series_equal,
@@ -718,11 +718,33 @@ def test_set_index_preserve_categorical_dtype(self):
718718

719719
class TestIntervalIndex(tm.TestCase):
720720

721-
def test_set_reset(self):
721+
def test_setitem(self):
722+
723+
df = DataFrame({'A': range(10)})
724+
s = pd.cut(df.A, 5)
725+
self.assertIsInstance(s.cat.categories, IntervalIndex)
726+
727+
# these should end up the same, namely
728+
# an object array of Intervals
729+
df['B'] = s
730+
df['C'] = np.array(s)
731+
df['D'] = s.values
732+
df['E'] = np.array(s.values)
733+
734+
self.assertTrue(df['B'].dtype == 'object')
735+
self.assertTrue(df['C'].dtype == 'object')
736+
self.assertTrue(df['D'].dtype == 'object')
737+
self.assertTrue(df['E'].dtype == 'object')
738+
739+
tm.assert_series_equal(df['B'], df['C'], check_names=False)
740+
tm.assert_series_equal(df['B'], df['D'], check_names=False)
741+
tm.assert_series_equal(df['B'], df['E'], check_names=False)
742+
743+
def test_set_reset_index(self):
744+
722745
df = DataFrame({'A': range(10)})
723-
df['B'] = pd.cut(df.A, 5)
746+
s = pd.cut(df.A, 5)
747+
df['B'] = s
724748
df = df.set_index('B')
725749

726-
# TODO: this should actually be converted prior
727-
self.assertTrue(isinstance(df.index, pd.CategoricalIndex))
728750
df = df.reset_index()

pandas/tests/indexing/test_interval.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,10 @@ def test_loc_getitem_series(self):
3333
def test_loc_getitem_frame(self):
3434

3535
df = DataFrame({'A': range(10)})
36-
df['B'] = pd.cut(df.A, 5)
36+
s = pd.cut(df.A, 5)
37+
df['B'] = s
3738
df = df.set_index('B')
3839

39-
# TODO: fixme
40-
# result = df.loc[4]
41-
# expected = df.iloc[3]
42-
# tm.assert_series_equal(result, expected)
40+
result = df.loc[4]
41+
expected = df.iloc[4:6]
42+
tm.assert_frame_equal(result, expected)

pandas/types/cast.py

+15-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from pandas.compat import string_types, text_type, PY3
88
from .common import (_ensure_object, is_bool, is_integer, is_float,
99
is_complex, is_datetimetz, is_categorical_dtype,
10-
is_datetimelike,
10+
is_datetimelike, is_interval_dtype,
1111
is_extension_type, is_object_dtype,
1212
is_datetime64tz_dtype, is_datetime64_dtype,
1313
is_timedelta64_dtype, is_dtype_equal,
@@ -484,6 +484,20 @@ def conv(r, dtype):
484484
return [conv(r, dtype) for r, dtype in zip(result, dtypes)]
485485

486486

487+
def _coerce_extension_to_embed(value):
488+
"""
489+
we have an extension type, coerce it to a type
490+
suitable for embedding (in a Series/DataFrame)
491+
"""
492+
493+
# TODO: maybe we should have a method on Categorical
494+
# to actually do this instead
495+
if is_categorical_dtype(value):
496+
if is_interval_dtype(value.categories):
497+
return np.array(value)
498+
499+
return value
500+
487501
def _astype_nansafe(arr, dtype, copy=True):
488502
""" return a view if copy is False, but
489503
need to be very careful as the result shape could change! """

0 commit comments

Comments
 (0)