Skip to content

POC: add closed argument to IndexSlice #27209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions pandas/core/dtypes/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _check(cls, inst):
("extension", "categorical", "periodarray", "datetimearray", "timedeltaarray"),
)
ABCPandasArray = create_pandas_abc_type("ABCPandasArray", "_typ", ("npy_extension",))
ABCIndexSlice = create_pandas_abc_type("ABCIndexSlc", "_typ", ("indexslice",))


class _ABCGeneric(type):
Expand Down
37 changes: 24 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ABCDateOffset,
ABCDatetimeArray,
ABCIndexClass,
ABCIndexSlice,
ABCMultiIndex,
ABCPandasArray,
ABCPeriodIndex,
Expand Down Expand Up @@ -3150,13 +3151,17 @@ def _convert_scalar_indexer(self, key, kind=None):
"""

@Appender(_index_shared_docs["_convert_slice_indexer"])
def _convert_slice_indexer(self, key, kind=None):
def _convert_slice_indexer(self, key, kind=None, closed=None):
assert kind in ["ix", "loc", "getitem", "iloc", None]

# if we are not a slice, then we are done
if not isinstance(key, slice):
if not isinstance(key, (slice, ABCIndexSlice)):
return key

if isinstance(key, ABCIndexSlice):
closed = key.closed
key = key.arg

# validate iloc
if kind == "iloc":
return slice(
Expand Down Expand Up @@ -3209,7 +3214,9 @@ def is_int(v):
indexer = key
else:
try:
indexer = self.slice_indexer(start, stop, step, kind=kind)
indexer = self.slice_indexer(
start, stop, step, kind=kind, closed=closed
)
except Exception:
if is_index_slice:
if self.is_integer():
Expand Down Expand Up @@ -4718,6 +4725,8 @@ def get_value(self, series, key):
raise
elif is_integer(key):
return s[key]
elif isinstance(key, ABCIndexSlice):
raise InvalidIndexError(key)

s = com.values_from_object(series)
k = com.values_from_object(key)
Expand Down Expand Up @@ -4990,7 +4999,7 @@ def _get_string_slice(self, key, use_lhs=True, use_rhs=True):
# overridden in DatetimeIndex, TimedeltaIndex and PeriodIndex
raise NotImplementedError

def slice_indexer(self, start=None, end=None, step=None, kind=None):
def slice_indexer(self, start=None, end=None, step=None, kind=None, closed=None):
"""
For an ordered or unique index, compute the slice indexer for input
labels and step.
Expand Down Expand Up @@ -5029,7 +5038,9 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None):
>>> idx.slice_indexer(start='b', end=('c', 'g'))
slice(1, 3)
"""
start_slice, end_slice = self.slice_locs(start, end, step=step, kind=kind)
start_slice, end_slice = self.slice_locs(
start, end, step=step, kind=kind, closed=closed
)

# return a slice
if not is_scalar(start_slice):
Expand Down Expand Up @@ -5093,7 +5104,7 @@ def _validate_indexer(self, form, key, kind):
"""

@Appender(_index_shared_docs["_maybe_cast_slice_bound"])
def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
assert kind in ["ix", "loc", "getitem", None]

# We are a plain index here (sub-class override this method if they
Expand Down Expand Up @@ -5125,7 +5136,7 @@ def _searchsorted_monotonic(self, label, side="left"):

raise ValueError("index must be monotonic increasing or decreasing")

def get_slice_bound(self, label, side, kind):
def get_slice_bound(self, label, side, kind, closed=None):
"""
Calculate slice bound that corresponds to given label.

Expand Down Expand Up @@ -5155,7 +5166,7 @@ def get_slice_bound(self, label, side, kind):

# For datetime indices label may be a string that has to be converted
# to datetime boundary according to its resolution.
label = self._maybe_cast_slice_bound(label, side, kind)
label = self._maybe_cast_slice_bound(label, side, kind, closed=closed)

# we need to look up the label
try:
Expand Down Expand Up @@ -5187,11 +5198,11 @@ def get_slice_bound(self, label, side, kind):
return slc.stop
else:
if side == "right":
return slc + 1
return slc + 1 if closed not in ["left", "neither"] else slc
else:
return slc
return slc if closed not in ["right", "neither"] else slc + 1

def slice_locs(self, start=None, end=None, step=None, kind=None):
def slice_locs(self, start=None, end=None, step=None, kind=None, closed=None):
"""
Compute slice locations for input labels.

Expand Down Expand Up @@ -5243,13 +5254,13 @@ def slice_locs(self, start=None, end=None, step=None, kind=None):

start_slice = None
if start is not None:
start_slice = self.get_slice_bound(start, "left", kind)
start_slice = self.get_slice_bound(start, "left", kind, closed=closed)
if start_slice is None:
start_slice = 0

end_slice = None
if end is not None:
end_slice = self.get_slice_bound(end, "right", kind)
end_slice = self.get_slice_bound(end, "right", kind, closed=closed)
if end_slice is None:
end_slice = len(self)

Expand Down
11 changes: 7 additions & 4 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def get_loc(self, key, method=None, tolerance=None):
raise e
raise KeyError(key)

def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
"""
If label is a string, cast it to datetime according to resolution.

Expand Down Expand Up @@ -1111,7 +1111,10 @@ def _maybe_cast_slice_bound(self, label, side, kind):
# and length 1 index)
if self._is_strictly_monotonic_decreasing and len(self) > 1:
return upper if side == "left" else lower
return lower if side == "left" else upper
if side == "left":
return lower if closed not in ["right", "neither"] else upper
else:
return upper if closed not in ["left", "neither"] else lower
else:
return label

Expand All @@ -1121,7 +1124,7 @@ def _get_string_slice(self, key, use_lhs=True, use_rhs=True):
loc = self._partial_date_slice(reso, parsed, use_lhs=use_lhs, use_rhs=use_rhs)
return loc

def slice_indexer(self, start=None, end=None, step=None, kind=None):
def slice_indexer(self, start=None, end=None, step=None, kind=None, closed=None):
"""
Return indexer for specified label slice.
Index.slice_indexer, customized to handle time slicing.
Expand All @@ -1147,7 +1150,7 @@ def slice_indexer(self, start=None, end=None, step=None, kind=None):
raise KeyError("Cannot mix time and non-time slice keys")

try:
return Index.slice_indexer(self, start, end, step, kind=kind)
return Index.slice_indexer(self, start, end, step, kind=kind, closed=closed)
except KeyError:
# For historical reasons DatetimeIndex by default supports
# value-based partial (aka string) slices on non-monotonic arrays,
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def _convert_scalar_indexer(self, key, kind=None):
return super()._convert_scalar_indexer(key, kind=kind)
return key

def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
return getattr(self, side)._maybe_cast_slice_bound(label, side, kind)

@Appender(_index_shared_docs["_convert_list_indexer"])
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,13 +2508,13 @@ def reindex(self, target, method=None, level=None, limit=None, tolerance=None):

return target, indexer

def get_slice_bound(self, label, side, kind):
def get_slice_bound(self, label, side, kind, closed=None):

if not isinstance(label, tuple):
label = (label,)
return self._partial_tup_index(label, side=side)

def slice_locs(self, start=None, end=None, step=None, kind=None):
def slice_locs(self, start=None, end=None, step=None, kind=None, closed=None):
"""
For an ordered MultiIndex, compute the slice locations for input
labels.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, fastpath=None):
return cls._simple_new(subarr, name=name)

@Appender(_index_shared_docs["_maybe_cast_slice_bound"])
def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
assert kind in ["ix", "loc", "getitem", None]

# we will try to coerce to integers
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,7 @@ def get_loc(self, key, method=None, tolerance=None):
except KeyError:
raise KeyError(key)

def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
"""
If label is a string or a datetime, cast it to Period.ordinal according
to resolution.
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def get_loc(self, key, method=None, tolerance=None):
except (KeyError, ValueError):
raise KeyError(key)

def _maybe_cast_slice_bound(self, label, side, kind):
def _maybe_cast_slice_bound(self, label, side, kind, closed=None):
"""
If label is a string, cast it to timedelta according to resolution.

Expand Down
41 changes: 33 additions & 8 deletions pandas/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
is_sequence,
is_sparse,
)
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexSlice, ABCSeries
from pandas.core.dtypes.missing import _infer_fill_value, isna

import pandas.core.common as com
from pandas.core.index import Index, MultiIndex

_VALID_CLOSED = {"left", "right", "both", "neither"}


# the supported indexers
def get_indexers_list():
Expand Down Expand Up @@ -85,8 +87,23 @@ class _IndexSlice:
B1 10 11
"""

_typ = "indexslice"

def __init__(self, closed=None):
if closed is not None and closed not in _VALID_CLOSED:
msg = "invalid option for 'closed': {closed}".format(closed=closed)
raise ValueError(msg)
self.closed = closed

def __call__(self, closed=None):
return _IndexSlice(closed=closed)

def __getitem__(self, arg):
return arg
if self.closed is None:
return arg
else:
self.arg = arg
return self


IndexSlice = _IndexSlice()
Expand Down Expand Up @@ -1455,8 +1472,9 @@ def __getitem__(self, key):
# we by definition only have the 0th axis
axis = self.axis or 0

maybe_callable = com.apply_if_callable(key, self.obj)
return self._getitem_axis(maybe_callable, axis=axis)
if not isinstance(key, ABCIndexSlice):
key = com.apply_if_callable(key, self.obj)
return self._getitem_axis(key, axis=axis)

def _is_scalar_access(self, key):
raise NotImplementedError()
Expand All @@ -1478,18 +1496,25 @@ def _getbool_axis(self, key, axis=None):
except Exception as detail:
raise self._exception(detail)

def _get_slice_axis(self, slice_obj, axis=None):
def _get_slice_axis(self, slice_obj, axis=None, closed=None):
""" this is pretty simple as we just have to deal with labels """
if axis is None:
axis = self.axis or 0
if isinstance(slice_obj, ABCIndexSlice):
closed = slice_obj.closed
slice_obj = slice_obj.arg

obj = self.obj
if not need_slice(slice_obj):
return obj.copy(deep=False)

labels = obj._get_axis(axis)
indexer = labels.slice_indexer(
slice_obj.start, slice_obj.stop, slice_obj.step, kind=self.name
slice_obj.start,
slice_obj.stop,
slice_obj.step,
kind=self.name,
closed=closed,
)

if isinstance(indexer, slice):
Expand Down Expand Up @@ -1751,7 +1776,7 @@ def _validate_key(self, key, axis):
# slice of integers (only if in the labels)
# boolean

if isinstance(key, slice):
if isinstance(key, (slice, ABCIndexSlice)):
return

if com.is_bool_indexer(key):
Expand Down Expand Up @@ -1823,7 +1848,7 @@ def _getitem_axis(self, key, axis=None):
labels = self.obj._get_axis(axis)
key = self._get_partial_string_timestamp_match_key(key, labels)

if isinstance(key, slice):
if isinstance(key, (slice, ABCIndexSlice)):
self._validate_key(key, axis)
return self._get_slice_axis(key, axis=axis)
elif com.is_bool_indexer(key):
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ABCDataFrame,
ABCDatetimeArray,
ABCDatetimeIndex,
ABCIndexSlice,
ABCSeries,
ABCSparseArray,
ABCSparseSeries,
Expand Down Expand Up @@ -1069,7 +1070,8 @@ def _slice(self, slobj, axis=0, kind=None):
return self._get_values(slobj)

def __getitem__(self, key):
key = com.apply_if_callable(key, self)
if not isinstance(key, ABCIndexSlice):
key = com.apply_if_callable(key, self)
try:
result = self.index.get_value(self, key)

Expand Down Expand Up @@ -1117,7 +1119,7 @@ def __getitem__(self, key):

def _get_with(self, key):
# other: fancy integer or otherwise
if isinstance(key, slice):
if isinstance(key, (slice, ABCIndexSlice)):
indexer = self.index._convert_slice_indexer(key, kind="getitem")
return self._get_values(indexer)
elif isinstance(key, ABCDataFrame):
Expand Down
Loading