Skip to content

API: Index.take inconsistently handle fill_value #12676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion doc/source/whatsnew/v0.18.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ New features
Enhancements
~~~~~~~~~~~~


.. _whatsnew_0181.partial_string_indexing:

Partial string indexing on ``DateTimeIndex`` when part of a ``MultiIndex``
Expand Down Expand Up @@ -59,6 +58,14 @@ Other Enhancements
- ``pd.read_csv()`` now supports opening ZIP files that contains a single CSV, via extension inference or explict ``compression='zip'`` (:issue:`12175`)
- ``pd.read_csv()`` now supports opening files using xz compression, via extension inference or explicit ``compression='xz'`` is specified; ``xz`` compressions is also supported by ``DataFrame.to_csv`` in the same way (:issue:`11852`)
- ``pd.read_msgpack()`` now always gives writeable ndarrays even when compression is used (:issue:`12359`).
- ``Index.take`` now handles ``allow_fill`` and ``fill_value`` consistently (:issue:`12631`)

.. ipython:: python

idx = pd.Index([1., 2., 3., 4.], dtype='float')
idx.take([2, -1]) # default, allow_fill=True, fill_value=None
idx.take([2, -1], fill_value=True)


.. _whatsnew_0181.api:

Expand Down
50 changes: 43 additions & 7 deletions pandas/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,24 +1329,60 @@ def _ensure_compat_concat(indexes):

return indexes

def take(self, indices, axis=0, allow_fill=True, fill_value=None):
"""
return a new Index of the values selected by the indexer
_index_shared_docs['take'] = """
return a new Index of the values selected by the indices

For internal compatibility with numpy arrays.

# filling must always be None/nan here
# but is passed thru internally
Parameters
----------
indices : list
Indices to be taken
axis : int, optional
The axis over which to select values, always 0.
allow_fill : bool, default True
fill_value : bool, default None
If allow_fill=True and fill_value is not None, indices specified by
-1 is regarded as NA. If Index doesn't hold NA, raise ValueError

See also
--------
numpy.ndarray.take
"""

@Appender(_index_shared_docs['take'])
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
indices = com._ensure_platform_int(indices)
taken = self.values.take(indices)
if self._can_hold_na:
taken = self._assert_take_fillable(self.values, indices,
allow_fill=allow_fill,
fill_value=fill_value,
na_value=self._na_value)
else:
if allow_fill and fill_value is not None:
msg = 'Unable to fill values because {0} cannot contain NA'
raise ValueError(msg.format(self.__class__.__name__))
taken = self.values.take(indices)
return self._shallow_copy(taken)

def _assert_take_fillable(self, values, indices, allow_fill=True,
fill_value=None, na_value=np.nan):
""" Internal method to handle NA filling of take """
indices = com._ensure_platform_int(indices)

# only fill if we are passing a non-None fill_value
if allow_fill and fill_value is not None:
if (indices < -1).any():
msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
raise ValueError(msg)
taken = values.take(indices)
mask = indices == -1
if mask.any():
taken[mask] = na_value
else:
taken = values.take(indices)
return taken

@cache_readonly
def _isnan(self):
""" return if each value is nan"""
Expand Down
22 changes: 7 additions & 15 deletions pandas/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,21 +459,13 @@ def _convert_list_indexer(self, keyarr, kind=None):

return None

def take(self, indexer, axis=0, allow_fill=True, fill_value=None):
"""
For internal compatibility with numpy arrays.

# filling must always be None/nan here
# but is passed thru internally
assert isnull(fill_value)

See also
--------
numpy.ndarray.take
"""

indexer = com._ensure_platform_int(indexer)
taken = self.codes.take(indexer)
@Appender(_index_shared_docs['take'])
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
indices = com._ensure_platform_int(indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think you need this ensure here (as its inside assert_take_fillabel)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we move it to assert_take_fillable, it will be executed duplicatelly when take has additional logic`` like below. No need to care?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I just realized that!

ok the platform_int thing is a whole other issue! (we shouldn't even use them at all), or convert JUST where needed (only for things like take).

taken = self._assert_take_fillable(self.codes, indices,
allow_fill=allow_fill,
fill_value=fill_value,
na_value=-1)
return self._create_from_codes(taken)

def delete(self, loc):
Expand Down
37 changes: 32 additions & 5 deletions pandas/indexes/multi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

# pylint: disable=E1101,E1103,W0232
import datetime
import warnings
Expand All @@ -11,7 +12,7 @@

from pandas.compat import range, zip, lrange, lzip, map
from pandas import compat
from pandas.core.base import FrozenList
from pandas.core.base import FrozenList, FrozenNDArray
import pandas.core.base as base
from pandas.util.decorators import (Appender, cache_readonly,
deprecate, deprecate_kwarg)
Expand Down Expand Up @@ -1003,12 +1004,38 @@ def __getitem__(self, key):
names=self.names, sortorder=sortorder,
verify_integrity=False)

def take(self, indexer, axis=None):
indexer = com._ensure_platform_int(indexer)
new_labels = [lab.take(indexer) for lab in self.labels]
return MultiIndex(levels=self.levels, labels=new_labels,
@Appender(_index_shared_docs['take'])
def take(self, indices, axis=0, allow_fill=True, fill_value=None):
indices = com._ensure_platform_int(indices)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

taken = self._assert_take_fillable(self.labels, indices,
allow_fill=allow_fill,
fill_value=fill_value,
na_value=-1)
return MultiIndex(levels=self.levels, labels=taken,
names=self.names, verify_integrity=False)

def _assert_take_fillable(self, values, indices, allow_fill=True,
fill_value=None, na_value=None):
""" Internal method to handle NA filling of take """
# only fill if we are passing a non-None fill_value
if allow_fill and fill_value is not None:
if (indices < -1).any():
msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
raise ValueError(msg)
taken = [lab.take(indices) for lab in self.labels]
mask = indices == -1
if mask.any():
masked = []
for new_label in taken:
label_values = new_label.values()
label_values[mask] = na_value
masked.append(base.FrozenNDArray(label_values))
taken = masked
else:
taken = [lab.take(indices) for lab in self.labels]
return taken

def append(self, other):
"""
Append a collection of Index options together
Expand Down
28 changes: 28 additions & 0 deletions pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,34 @@ def test_nan_first_take_datetime(self):
exp = Index([idx[-1], idx[0], idx[1]])
tm.assert_index_equal(res, exp)

def test_take_fill_value(self):
# GH 12631
idx = pd.Index(list('ABC'), name='xxx')
result = idx.take(np.array([1, 0, -1]))
expected = pd.Index(list('BAC'), name='xxx')
tm.assert_index_equal(result, expected)

# fill_value
result = idx.take(np.array([1, 0, -1]), fill_value=True)
expected = pd.Index(['B', 'A', np.nan], name='xxx')
tm.assert_index_equal(result, expected)

# allow_fill=False
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
fill_value=True)
expected = pd.Index(['B', 'A', 'C'], name='xxx')
tm.assert_index_equal(result, expected)

msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -2]), fill_value=True)
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -5]), fill_value=True)

with tm.assertRaises(IndexError):
idx.take(np.array([1, -5]))

def test_reindex_preserves_name_if_target_is_list_or_ndarray(self):
# GH6552
idx = pd.Index([0, 1, 2])
Expand Down
97 changes: 97 additions & 0 deletions pandas/tests/indexes/test_category.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,100 @@ def test_fillna_categorical(self):
with tm.assertRaisesRegexp(ValueError,
'fill value must be in categories'):
idx.fillna(2.0)

def test_take_fill_value(self):
# GH 12631

# numeric category
idx = pd.CategoricalIndex([1, 2, 3], name='xxx')
result = idx.take(np.array([1, 0, -1]))
expected = pd.CategoricalIndex([2, 1, 3], name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

# fill_value
result = idx.take(np.array([1, 0, -1]), fill_value=True)
expected = pd.CategoricalIndex([2, 1, np.nan], categories=[1, 2, 3],
name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

# allow_fill=False
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
fill_value=True)
expected = pd.CategoricalIndex([2, 1, 3], name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

# object category
idx = pd.CategoricalIndex(list('CBA'), categories=list('ABC'),
ordered=True, name='xxx')
result = idx.take(np.array([1, 0, -1]))
expected = pd.CategoricalIndex(list('BCA'), categories=list('ABC'),
ordered=True, name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

# fill_value
result = idx.take(np.array([1, 0, -1]), fill_value=True)
expected = pd.CategoricalIndex(['B', 'C', np.nan],
categories=list('ABC'), ordered=True,
name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

# allow_fill=False
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
fill_value=True)
expected = pd.CategoricalIndex(list('BCA'), categories=list('ABC'),
ordered=True, name='xxx')
tm.assert_index_equal(result, expected)
tm.assert_categorical_equal(result.values, expected.values)

msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -2]), fill_value=True)
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -5]), fill_value=True)

with tm.assertRaises(IndexError):
idx.take(np.array([1, -5]))

def test_take_fill_value_datetime(self):

# datetime category
idx = pd.DatetimeIndex(['2011-01-01', '2011-02-01', '2011-03-01'],
name='xxx')
idx = pd.CategoricalIndex(idx)
result = idx.take(np.array([1, 0, -1]))
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', '2011-03-01'],
name='xxx')
expected = pd.CategoricalIndex(expected)
tm.assert_index_equal(result, expected)

# fill_value
result = idx.take(np.array([1, 0, -1]), fill_value=True)
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', 'NaT'],
name='xxx')
exp_cats = pd.DatetimeIndex(['2011-01-01', '2011-02-01', '2011-03-01'])
expected = pd.CategoricalIndex(expected, categories=exp_cats)
tm.assert_index_equal(result, expected)

# allow_fill=False
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
fill_value=True)
expected = pd.DatetimeIndex(['2011-02-01', '2011-01-01', '2011-03-01'],
name='xxx')
expected = pd.CategoricalIndex(expected)
tm.assert_index_equal(result, expected)

msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -2]), fill_value=True)
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -5]), fill_value=True)

with tm.assertRaises(IndexError):
idx.take(np.array([1, -5]))
40 changes: 40 additions & 0 deletions pandas/tests/indexes/test_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,46 @@ def test_take_preserve_name(self):
taken = self.index.take([3, 0, 1])
self.assertEqual(taken.names, self.index.names)

def test_take_fill_value(self):
# GH 12631
vals = [['A', 'B'],
[pd.Timestamp('2011-01-01'), pd.Timestamp('2011-01-02')]]
idx = pd.MultiIndex.from_product(vals, names=['str', 'dt'])

result = idx.take(np.array([1, 0, -1]))
exp_vals = [('A', pd.Timestamp('2011-01-02')),
('A', pd.Timestamp('2011-01-01')),
('B', pd.Timestamp('2011-01-02'))]
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
tm.assert_index_equal(result, expected)

# fill_value
result = idx.take(np.array([1, 0, -1]), fill_value=True)
exp_vals = [('A', pd.Timestamp('2011-01-02')),
('A', pd.Timestamp('2011-01-01')),
(np.nan, pd.NaT)]
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
tm.assert_index_equal(result, expected)

# allow_fill=False
result = idx.take(np.array([1, 0, -1]), allow_fill=False,
fill_value=True)
exp_vals = [('A', pd.Timestamp('2011-01-02')),
('A', pd.Timestamp('2011-01-01')),
('B', pd.Timestamp('2011-01-02'))]
expected = pd.MultiIndex.from_tuples(exp_vals, names=['str', 'dt'])
tm.assert_index_equal(result, expected)

msg = ('When allow_fill=True and fill_value is not None, '
'all indices must be >= -1')
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -2]), fill_value=True)
with tm.assertRaisesRegexp(ValueError, msg):
idx.take(np.array([1, 0, -5]), fill_value=True)

with tm.assertRaises(IndexError):
idx.take(np.array([1, -5]))

def test_join_level(self):
def _check_how(other, how):
join_index, lidx, ridx = other.join(self.index, how=how,
Expand Down
Loading