Skip to content

BUG: Allow MultiIndex to be subclassed #11267 #11268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 70 additions & 70 deletions pandas/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3906,7 +3906,7 @@ def __new__(cls, levels=None, labels=None, sortorder=None, names=None,
name = None
return Index(levels[0], name=name, copy=True).take(labels[0])

result = object.__new__(MultiIndex)
result = object.__new__(cls)

# we've already validated levels and labels, so shortcut here
result._set_levels(levels, copy=copy, validate=False)
Expand Down Expand Up @@ -4184,12 +4184,12 @@ def copy(self, names=None, dtype=None, levels=None, labels=None,
levels = self.levels
labels = self.labels
names = self.names
return MultiIndex(levels=levels,
labels=labels,
names=names,
sortorder=self.sortorder,
verify_integrity=False,
_set_identity=_set_identity)
return self.__class__(levels=levels,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be made consistent, e.g. use @constructor, so this is a no-go (you may need to rejigger where the @constructor is used now as its MultiIndex.from_tuples)

labels=labels,
names=names,
sortorder=self.sortorder,
verify_integrity=False,
_set_identity=_set_identity)

def __array__(self, dtype=None):
""" the array interface, return my values """
Expand All @@ -4205,7 +4205,7 @@ def _shallow_copy(self, values=None, infer=False, **kwargs):
if values is not None:
if 'name' in kwargs:
kwargs['names'] = kwargs.pop('name',None)
return MultiIndex.from_tuples(values, **kwargs)
return self.__class__.from_tuples(values, **kwargs)
return self.view()

@cache_readonly
Expand Down Expand Up @@ -4285,16 +4285,16 @@ def _format_native_types(self, **kwargs):

@property
def _constructor(self):
return MultiIndex.from_tuples
return self.__class__.from_tuples

@cache_readonly
def inferred_type(self):
return 'mixed'

@staticmethod
def _from_elements(values, labels=None, levels=None, names=None,
@classmethod
def _from_elements(cls, values, labels=None, levels=None, names=None,
sortorder=None):
return MultiIndex(levels, labels, names, sortorder=sortorder)
return cls(levels, labels, names, sortorder=sortorder)

def _get_level_number(self, level):
try:
Expand Down Expand Up @@ -4552,7 +4552,7 @@ def to_hierarchical(self, n_repeat, n_shuffle=1):
# Assumes that each label is divisible by n_shuffle
labels = [x.reshape(n_shuffle, -1).ravel(1) for x in labels]
names = self.names
return MultiIndex(levels=levels, labels=labels, names=names)
return self.__class__(levels=levels, labels=labels, names=names)

@property
def is_all_dates(self):
Expand Down Expand Up @@ -4626,9 +4626,9 @@ def from_arrays(cls, arrays, sortorder=None, names=None):
if names is None:
names = [getattr(arr, "name", None) for arr in arrays]

return MultiIndex(levels=levels, labels=labels,
sortorder=sortorder, names=names,
verify_integrity=False)
return cls(levels=levels, labels=labels,
sortorder=sortorder, names=names,
verify_integrity=False)

@classmethod
def from_tuples(cls, tuples, sortorder=None, names=None):
Expand Down Expand Up @@ -4673,8 +4673,8 @@ def from_tuples(cls, tuples, sortorder=None, names=None):
else:
arrays = lzip(*tuples)

return MultiIndex.from_arrays(arrays, sortorder=sortorder,
names=names)
return cls.from_arrays(arrays, sortorder=sortorder,
names=names)

@classmethod
def from_product(cls, iterables, sortorder=None, names=None):
Expand Down Expand Up @@ -4716,8 +4716,8 @@ def from_product(cls, iterables, sortorder=None, names=None):
categoricals = [Categorical.from_array(it, ordered=True) for it in iterables]
labels = cartesian_product([c.codes for c in categoricals])

return MultiIndex(levels=[c.categories for c in categoricals],
labels=labels, sortorder=sortorder, names=names)
return cls(levels=[c.categories for c in categoricals],
labels=labels, sortorder=sortorder, names=names)

@property
def nlevels(self):
Expand Down Expand Up @@ -4785,17 +4785,17 @@ def __getitem__(self, key):

new_labels = [lab[key] for lab in self.labels]

return MultiIndex(levels=self.levels,
labels=new_labels,
names=self.names,
sortorder=sortorder,
verify_integrity=False)
return self.__class__(levels=self.levels,
labels=new_labels,
names=self.names,
sortorder=sortorder,
verify_integrity=False)

def take(self, indexer, axis=None):
indexer = com._ensure_platform_int(indexer)
new_labels = [lab.take(indexer) for lab in self.labels]
return MultiIndex(levels=self.levels, labels=new_labels,
names=self.names, verify_integrity=False)
return self.__class__(levels=self.levels, labels=new_labels,
names=self.names, verify_integrity=False)

def append(self, other):
"""
Expand All @@ -4818,26 +4818,26 @@ def append(self, other):
label = self.get_level_values(i)
appended = [o.get_level_values(i) for o in other]
arrays.append(label.append(appended))
return MultiIndex.from_arrays(arrays, names=self.names)
return self.__class__.from_arrays(arrays, names=self.names)

to_concat = (self.values,) + tuple(k._values for k in other)
new_tuples = np.concatenate(to_concat)

# if all(isinstance(x, MultiIndex) for x in other):
try:
return MultiIndex.from_tuples(new_tuples, names=self.names)
return self.__class__.from_tuples(new_tuples, names=self.names)
except:
return Index(new_tuples)

def argsort(self, *args, **kwargs):
return self.values.argsort(*args, **kwargs)

def repeat(self, n):
return MultiIndex(levels=self.levels,
labels=[label.view(np.ndarray).repeat(n) for label in self.labels],
names=self.names,
sortorder=self.sortorder,
verify_integrity=False)
return self.__class__(levels=self.levels,
labels=[label.view(np.ndarray).repeat(n) for label in self.labels],
names=self.names,
sortorder=self.sortorder,
verify_integrity=False)

def drop(self, labels, level=None, errors='raise'):
"""
Expand Down Expand Up @@ -4936,8 +4936,8 @@ def droplevel(self, level=0):
result.name = new_names[0]
return result
else:
return MultiIndex(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)
return self.__class__(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)

def swaplevel(self, i, j):
"""
Expand All @@ -4963,8 +4963,8 @@ def swaplevel(self, i, j):
new_labels[i], new_labels[j] = new_labels[j], new_labels[i]
new_names[i], new_names[j] = new_names[j], new_names[i]

return MultiIndex(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)
return self.__class__(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)

def reorder_levels(self, order):
"""
Expand All @@ -4982,8 +4982,8 @@ def reorder_levels(self, order):
new_labels = [self.labels[i] for i in order]
new_names = [self.names[i] for i in order]

return MultiIndex(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)
return self.__class__(levels=new_levels, labels=new_labels,
names=new_names, verify_integrity=False)

def __getslice__(self, i, j):
return self.__getitem__(slice(i, j))
Expand Down Expand Up @@ -5048,9 +5048,9 @@ def sortlevel(self, level=0, ascending=True, sort_remaining=True):
indexer = com._ensure_platform_int(indexer)
new_labels = [lab.take(indexer) for lab in self.labels]

new_index = MultiIndex(labels=new_labels, levels=self.levels,
names=self.names, sortorder=sortorder,
verify_integrity=False)
new_index = self.__class__(labels=new_labels, levels=self.levels,
names=self.names, sortorder=sortorder,
verify_integrity=False)

return new_index, indexer

Expand Down Expand Up @@ -5165,7 +5165,7 @@ def reindex(self, target, method=None, level=None, limit=None,
target = self.take(indexer)
else:
# hopefully?
target = MultiIndex.from_tuples(target)
target = self.__class__.from_tuples(target)

if (preserve_names and target.nlevels == self.nlevels and
target.names != self.names):
Expand Down Expand Up @@ -5665,8 +5665,8 @@ def truncate(self, before=None, after=None):
new_labels = [lab[left:right] for lab in self.labels]
new_labels[0] = new_labels[0] - i

return MultiIndex(levels=new_levels, labels=new_labels,
verify_integrity=False)
return self.__class__(levels=new_levels, labels=new_labels,
verify_integrity=False)

def equals(self, other):
"""
Expand Down Expand Up @@ -5734,8 +5734,8 @@ def union(self, other):
return self

uniq_tuples = lib.fast_unique_multiple([self._values, other._values])
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)
return self.__class__.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def intersection(self, other):
"""
Expand All @@ -5759,12 +5759,12 @@ def intersection(self, other):
other_tuples = other._values
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
if len(uniq_tuples) == 0:
return MultiIndex(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
return self.__class__(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
else:
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)
return self.__class__.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def difference(self, other):
"""
Expand All @@ -5781,19 +5781,19 @@ def difference(self, other):
return self

if self.equals(other):
return MultiIndex(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
return self.__class__(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)

difference = sorted(set(self._values) - set(other._values))

if len(difference) == 0:
return MultiIndex(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
return self.__class__(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
names=result_names, verify_integrity=False)
else:
return MultiIndex.from_tuples(difference, sortorder=0,
names=result_names)
return self.__class__.from_tuples(difference, sortorder=0,
names=result_names)

def astype(self, dtype):
if not is_object_dtype(np.dtype(dtype)):
Expand All @@ -5806,13 +5806,13 @@ def _convert_can_do_setop(self, other):

if not hasattr(other, 'names'):
if len(other) == 0:
other = MultiIndex(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
verify_integrity=False)
other = self.__class__(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
verify_integrity=False)
else:
msg = 'other must be a MultiIndex or a list of tuples'
try:
other = MultiIndex.from_tuples(other)
other = self.__class__.from_tuples(other)
except:
raise TypeError(msg)
else:
Expand Down Expand Up @@ -5856,8 +5856,8 @@ def insert(self, loc, item):
new_levels.append(level)
new_labels.append(np.insert(_ensure_int64(labels), loc, lev_loc))

return MultiIndex(levels=new_levels, labels=new_labels,
names=self.names, verify_integrity=False)
return self.__class__(levels=new_levels, labels=new_labels,
names=self.names, verify_integrity=False)

def delete(self, loc):
"""
Expand All @@ -5868,8 +5868,8 @@ def delete(self, loc):
new_index : MultiIndex
"""
new_labels = [np.delete(lab, loc) for lab in self.labels]
return MultiIndex(levels=self.levels, labels=new_labels,
names=self.names, verify_integrity=False)
return self.__class__(levels=self.levels, labels=new_labels,
names=self.names, verify_integrity=False)

get_major_bounds = slice_locs

Expand All @@ -5889,7 +5889,7 @@ def _bounds(self):

def _wrap_joined_index(self, joined, other):
names = self.names if self.names == other.names else None
return MultiIndex.from_tuples(joined, names=names)
return self.__class__.from_tuples(joined, names=names)

@Appender(Index.isin.__doc__)
def isin(self, values, level=None):
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -5673,6 +5673,13 @@ def test_equals_operator(self):
# GH9785
self.assertTrue((self.index == self.index).all())

def test_subclassing(self):
# GH11267
class MyMultiIndex(MultiIndex):
pass
mi = MyMultiIndex([['a'], ['b']], [[0], [0]])
self.assertTrue(isinstance(mi, MyMultiIndex))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs a fair bit of testing, e.g. run thru the some typical operations of a MI, just with a sub-class one.



def test_get_combined_index():
from pandas.core.index import _get_combined_index
Expand Down