Skip to content

Commit

Permalink
BUG: MultiIndex intersection with sort=False does not preserve order (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Jean-Francois Zinque authored Feb 12, 2020
1 parent 143b011 commit c2f3ce3
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 30 deletions.
39 changes: 39 additions & 0 deletions asv_bench/benchmarks/multiindex_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,43 @@ def time_equals_non_object_index(self):
self.mi_large_slow.equals(self.idx_non_object)


class SetOperations:
    """Benchmark set operations between two ``MultiIndex`` objects.

    Parametrized over the ordering of the index (as built, or reversed),
    the kind of values in the second level, and the set-operation method
    that is timed.
    """

    params = [
        ("monotonic", "non_monotonic"),
        ("datetime", "int", "string"),
        ("intersection", "union", "symmetric_difference"),
    ]
    param_names = ["index_structure", "dtype", "method"]

    def setup(self, index_structure, dtype, method):
        """Create ``self.left`` and ``self.right`` for one parameter combination.

        ``left`` is the product of ``range(1000)`` with a second level of
        ``N // 1000`` values whose kind is selected by ``dtype``; it is
        reversed when ``index_structure == "non_monotonic"``. ``right`` is
        ``left`` without its last element.
        """
        N = 10 ** 5
        level1 = range(1000)
        n_level2 = N // 1000

        # Only the second level depends on ``dtype``. The factories are
        # called lazily so the two unused indexes are never constructed.
        level2_factories = {
            "datetime": lambda: date_range(start="1/1/2000", periods=n_level2),
            "int": lambda: range(n_level2),
            "string": lambda: tm.makeStringIndex(n_level2).values,
        }
        left = MultiIndex.from_product([level1, level2_factories[dtype]()])

        if index_structure == "non_monotonic":
            left = left[::-1]

        self.left = left
        self.right = left[:-1]

    def time_operation(self, index_structure, dtype, method):
        """Time ``self.left.<method>(self.right)``."""
        getattr(self.left, method)(self.right)


from .pandas_vb_common import setup # noqa: F401 isort:skip
10 changes: 10 additions & 0 deletions doc/source/whatsnew/v1.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ MultiIndex
index=[["a", "a", "b", "b"], [1, 2, 1, 2]])
# Rows are now ordered as the requested keys
df.loc[(['b', 'a'], [2, 1]), :]
- Bug in :meth:`MultiIndex.intersection` where the order of the result was not guaranteed to be preserved when ``sort=False``. (:issue:`31325`)

.. ipython:: python

   left = pd.MultiIndex.from_arrays([["b", "a"], [2, 1]])
   right = pd.MultiIndex.from_arrays([["a", "b", "c"], [1, 2, 3]])
   # Common elements are now guaranteed to be ordered by the left side
   left.intersection(right, sort=False)
-

I/O
Expand Down
20 changes: 17 additions & 3 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3314,9 +3314,23 @@ def intersection(self, other, sort=False):
if self.equals(other):
return self

self_tuples = self._ndarray_values
other_tuples = other._ndarray_values
uniq_tuples = set(self_tuples) & set(other_tuples)
lvals = self._ndarray_values
rvals = other._ndarray_values

uniq_tuples = None # flag whether _inner_indexer was successful
if self.is_monotonic and other.is_monotonic:
try:
uniq_tuples = self._inner_indexer(lvals, rvals)[0]
sort = False # uniq_tuples is already sorted
except TypeError:
pass

if uniq_tuples is None:
other_uniq = set(rvals)
seen = set()
uniq_tuples = [
x for x in lvals if x in other_uniq and not (x in seen or seen.add(x))
]

if sort is None:
uniq_tuples = sorted(uniq_tuples)
Expand Down
50 changes: 23 additions & 27 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,45 +19,41 @@ def test_set_ops_error_cases(idx, case, sort, method):


@pytest.mark.parametrize("sort", [None, False])
def test_intersection_base(idx, sort):
first = idx[:5]
second = idx[:3]
intersect = first.intersection(second, sort=sort)
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
def test_intersection_base(idx, sort, klass):
first = idx[2::-1] # first 3 elements reversed
second = idx[:5]

if sort is None:
tm.assert_index_equal(intersect, second.sort_values())
assert tm.equalContents(intersect, second)
if klass is not MultiIndex:
second = klass(second.values)

# GH 10149
cases = [klass(second.values) for klass in [np.array, Series, list]]
for case in cases:
result = first.intersection(case, sort=sort)
if sort is None:
tm.assert_index_equal(result, second.sort_values())
assert tm.equalContents(result, second)
intersect = first.intersection(second, sort=sort)
if sort is None:
expected = first.sort_values()
else:
expected = first
tm.assert_index_equal(intersect, expected)

msg = "other must be a MultiIndex or a list of tuples"
with pytest.raises(TypeError, match=msg):
first.intersection([1, 2, 3], sort=sort)


@pytest.mark.parametrize("sort", [None, False])
def test_union_base(idx, sort):
first = idx[3:]
@pytest.mark.parametrize("klass", [MultiIndex, np.array, Series, list])
def test_union_base(idx, sort, klass):
first = idx[::-1]
second = idx[:5]
everything = idx

if klass is not MultiIndex:
second = klass(second.values)

union = first.union(second, sort=sort)
if sort is None:
tm.assert_index_equal(union, everything.sort_values())
assert tm.equalContents(union, everything)

# GH 10149
cases = [klass(second.values) for klass in [np.array, Series, list]]
for case in cases:
result = first.union(case, sort=sort)
if sort is None:
tm.assert_index_equal(result, everything.sort_values())
assert tm.equalContents(result, everything)
expected = first.sort_values()
else:
expected = first
tm.assert_index_equal(union, expected)

msg = "other must be a MultiIndex or a list of tuples"
with pytest.raises(TypeError, match=msg):
Expand Down

0 comments on commit c2f3ce3

Please sign in to comment.