diff --git a/doc/source/whatsnew/v0.25.0.rst b/doc/source/whatsnew/v0.25.0.rst
index 09626be713c4f..b96fa5bf360a2 100644
--- a/doc/source/whatsnew/v0.25.0.rst
+++ b/doc/source/whatsnew/v0.25.0.rst
@@ -20,7 +20,7 @@ Other Enhancements
 ^^^^^^^^^^^^^^^^^^
 
 - :meth:`Timestamp.replace` now supports the ``fold`` argument to disambiguate DST transition times (:issue:`25017`)
--
+- :meth:`TimedeltaIndex.intersection` now also supports the ``sort`` keyword.
 -
 
 .. _whatsnew_0250.api_breaking:
@@ -88,7 +88,7 @@ Datetimelike
 Timedelta
 ^^^^^^^^^
 
--
+- Bug in :meth:`TimedeltaIndex.intersection` where for non-monotonic indices in some cases an empty Index was returned when in fact an intersection existed.
 -
 -
 
diff --git a/pandas/core/indexes/datetimelike.py b/pandas/core/indexes/datetimelike.py
index aa7332472fc07..82c4c618e4ff0 100644
--- a/pandas/core/indexes/datetimelike.py
+++ b/pandas/core/indexes/datetimelike.py
@@ -27,6 +27,7 @@
 from pandas.core.tools.timedeltas import to_timedelta
 
 import pandas.io.formats.printing as printing
+from pandas.tseries.frequencies import to_offset
 
 _index_doc_kwargs = dict(ibase._index_doc_kwargs)
 
@@ -530,6 +531,83 @@ def isin(self, values):
 
         return algorithms.isin(self.asi8, values.asi8)
 
+    def intersection(self, other, sort=False):
+        """
+        Shared intersection for DatetimeIndex and TimedeltaIndex objects.
+
+        When both indexes are monotonic and share the same anchored
+        ``freq`` the overlap is sliced out of the index that starts first;
+        every other case is delegated to ``Index.intersection``.
+
+        Parameters
+        ----------
+        other : same type as self or array-like
+        sort : False or None, default False
+            Forwarded to ``Index.intersection`` on the delegated paths.
+
+        Returns
+        -------
+        y : Index or same type as self
+        """
+        self._validate_sort_keyword(sort)
+        self._assert_can_do_setop(other)
+
+        if self.equals(other):
+            return self._get_reconciled_name_object(other)
+
+        if not isinstance(other, type(self)):
+            try:
+                # best-effort coercion to our own index type; on failure
+                # ``other`` is handed to Index.intersection unchanged
+                other = type(self)(other)
+            except (TypeError, ValueError):
+                pass
+            result = Index.intersection(self, other, sort=sort)
+            if isinstance(result, type(self)):
+                if result.freq is None:
+                    result.freq = to_offset(result.inferred_freq)
+            return result
+
+        elif (other.freq is None or self.freq is None or
+              other.freq != self.freq or
+              not other.freq.isAnchored() or
+              (not self.is_monotonic or not other.is_monotonic)):
+            result = Index.intersection(self, other, sort=sort)
+
+            # Invalidate the freq of `result`, which may not be correct at
+            # this point, depending on the values.
+            result.freq = None
+            if hasattr(self, 'tz'):
+                result = self._shallow_copy(result._values, name=result.name,
+                                            tz=result.tz, freq=None)
+            else:
+                result = self._shallow_copy(result._values, name=result.name,
+                                            freq=None)
+            if result.freq is None:
+                result.freq = to_offset(result.inferred_freq)
+            return result
+
+        if len(self) == 0:
+            return self
+        if len(other) == 0:
+            return other
+
+        # to make our life easier, "sort" the two ranges
+        if self[0] <= other[0]:
+            left, right = self, other
+        else:
+            left, right = other, self
+
+        end = min(left[-1], right[-1])
+        start = right[0]
+
+        if end < start:
+            return type(self)(data=[])
+        else:
+            lslice = slice(*left.slice_locs(start, end))
+            left_chunk = left.values[lslice]
+            return self._shallow_copy(left_chunk)
+
     @Appender(_index_shared_docs['repeat'] % _index_doc_kwargs)
     def repeat(self, repeats, axis=None):
         nv.validate_repeat(tuple(), dict(axis=axis))
diff --git a/pandas/core/indexes/datetimes.py b/pandas/core/indexes/datetimes.py
index 9c46860eb49d6..6bfe657c1b74a 100644
--- a/pandas/core/indexes/datetimes.py
+++ b/pandas/core/indexes/datetimes.py
@@ -590,14 +590,10 @@ def _fast_union(self, other):
         else:
             return left
 
-    def _wrap_setop_result(self, other, result):
-        name = get_op_result_name(self, other)
-        return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
-
     def intersection(self, other, sort=False):
         """
-        Specialized intersection for DatetimeIndex objects. May be much faster
-        than Index.intersection
+        Specialized intersection for DatetimeIndex objects.
+        May be much faster than Index.intersection
 
         Parameters
         ----------
@@ -614,58 +610,13 @@ def intersection(self, other, sort=False):
 
         Returns
         -------
-        y : Index or DatetimeIndex
+        y : Index or DatetimeIndex or TimedeltaIndex
         """
-        self._validate_sort_keyword(sort)
-        self._assert_can_do_setop(other)
-
-        if self.equals(other):
-            return self._get_reconciled_name_object(other)
-
-        if not isinstance(other, DatetimeIndex):
-            try:
-                other = DatetimeIndex(other)
-            except (TypeError, ValueError):
-                pass
-            result = Index.intersection(self, other, sort=sort)
-            if isinstance(result, DatetimeIndex):
-                if result.freq is None:
-                    result.freq = to_offset(result.inferred_freq)
-            return result
+        return super(DatetimeIndex, self).intersection(other, sort=sort)
 
-        elif (other.freq is None or self.freq is None or
-              other.freq != self.freq or
-              not other.freq.isAnchored() or
-              (not self.is_monotonic or not other.is_monotonic)):
-            result = Index.intersection(self, other, sort=sort)
-            # Invalidate the freq of `result`, which may not be correct at
-            # this point, depending on the values.
-            result.freq = None
-            result = self._shallow_copy(result._values, name=result.name,
-                                        tz=result.tz, freq=None)
-            if result.freq is None:
-                result.freq = to_offset(result.inferred_freq)
-            return result
-
-        if len(self) == 0:
-            return self
-        if len(other) == 0:
-            return other
-        # to make our life easier, "sort" the two ranges
-        if self[0] <= other[0]:
-            left, right = self, other
-        else:
-            left, right = other, self
-
-        end = min(left[-1], right[-1])
-        start = right[0]
-
-        if end < start:
-            return type(self)(data=[])
-        else:
-            lslice = slice(*left.slice_locs(start, end))
-            left_chunk = left.values[lslice]
-            return self._shallow_copy(left_chunk)
+    def _wrap_setop_result(self, other, result):
+        name = get_op_result_name(self, other)
+        return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
 
     # --------------------------------------------------------------------
 
diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py
index a4bd7f9017eb4..0ddf7ec106698 100644
--- a/pandas/core/indexes/period.py
+++ b/pandas/core/indexes/period.py
@@ -802,6 +802,10 @@ def join(self, other, how='left', level=None, return_indexers=False,
             return self._apply_meta(result), lidx, ridx
         return self._apply_meta(result)
 
+    @Appender(Index.intersection.__doc__)
+    def intersection(self, other, sort=False):
+        return Index.intersection(self, other, sort=sort)
+
     def _assert_can_do_setop(self, other):
         super(PeriodIndex, self)._assert_can_do_setop(other)
 
diff --git a/pandas/core/indexes/timedeltas.py b/pandas/core/indexes/timedeltas.py
index cbe5ae198838f..365a0a4a48f3d 100644
--- a/pandas/core/indexes/timedeltas.py
+++ b/pandas/core/indexes/timedeltas.py
@@ -378,6 +378,34 @@ def join(self, other, how='left', level=None, return_indexers=False,
                          return_indexers=return_indexers,
                          sort=sort)
 
+    def intersection(self, other, sort=False):
+        """
+        Specialized intersection for TimedeltaIndex objects.
+        May be much faster than Index.intersection
+
+        Parameters
+        ----------
+        other : TimedeltaIndex or array-like
+        sort : False or None, default False
+            Sort the resulting index if possible.
+
+            .. versionadded:: 0.24.0
+
+            .. versionchanged:: 0.24.1
+
+               Changed the default to ``False`` to match the behaviour
+               from before 0.24.0.
+
+            .. versionchanged:: 0.25.0
+
+               The `sort` keyword has been added to TimedeltaIndex as well.
+
+        Returns
+        -------
+        y : Index or DatetimeIndex or TimedeltaIndex
+        """
+        return super(TimedeltaIndex, self).intersection(other, sort=sort)
+
     def _wrap_joined_index(self, joined, other):
         name = get_op_result_name(self, other)
         if (isinstance(other, TimedeltaIndex) and self.freq == other.freq and
@@ -439,52 +467,6 @@ def _fast_union(self, other):
         else:
             return left
 
-    def intersection(self, other):
-        """
-        Specialized intersection for TimedeltaIndex objects. May be much faster
-        than Index.intersection
-
-        Parameters
-        ----------
-        other : TimedeltaIndex or array-like
-
-        Returns
-        -------
-        y : Index or TimedeltaIndex
-        """
-        self._assert_can_do_setop(other)
-
-        if self.equals(other):
-            return self._get_reconciled_name_object(other)
-
-        if not isinstance(other, TimedeltaIndex):
-            try:
-                other = TimedeltaIndex(other)
-            except (TypeError, ValueError):
-                pass
-            result = Index.intersection(self, other)
-            return result
-
-        if len(self) == 0:
-            return self
-        if len(other) == 0:
-            return other
-        # to make our life easier, "sort" the two ranges
-        if self[0] <= other[0]:
-            left, right = self, other
-        else:
-            left, right = other, self
-
-        end = min(left[-1], right[-1])
-        start = right[0]
-
-        if end < start:
-            return type(self)(data=[])
-        else:
-            lslice = slice(*left.slice_locs(start, end))
-            left_chunk = left.values[lslice]
-            return self._shallow_copy(left_chunk)
-
     def _maybe_promote(self, other):
         if other.inferred_type == 'timedelta':
             other = TimedeltaIndex(other)
diff --git a/pandas/tests/indexes/timedeltas/test_setops.py b/pandas/tests/indexes/timedeltas/test_setops.py
index f7c3f764df0a0..28fb3479a2347 100644
--- a/pandas/tests/indexes/timedeltas/test_setops.py
+++ b/pandas/tests/indexes/timedeltas/test_setops.py
@@ -1,9 +1,12 @@
 import numpy as np
+import pytest
 
 import pandas as pd
 from pandas import Int64Index, TimedeltaIndex, timedelta_range
 import pandas.util.testing as tm
 
+from pandas.tseries.offsets import Hour
+
 
 class TestTimedeltaIndex(object):
 
@@ -73,3 +76,90 @@ def test_intersection_bug_1708(self):
         result = index_1 & index_2
         expected = timedelta_range('1 day 01:00:00', periods=3, freq='h')
         tm.assert_index_equal(result, expected)
+
+    @pytest.mark.parametrize("sort", [None, False])
+    def test_intersection_equal(self, sort):
+        # for equal indices intersection should return the original index
+        first = timedelta_range('1 day', periods=4, freq='h')
+        second = timedelta_range('1 day', periods=4, freq='h')
+        intersect = first.intersection(second, sort=sort)
+        if sort is None:
+            tm.assert_index_equal(intersect, second.sort_values())
+        assert tm.equalContents(intersect, second)
+
+        # Corner cases
+        inter = first.intersection(first, sort=sort)
+        assert inter is first
+
+    @pytest.mark.parametrize("period_1, period_2", [(0, 4), (4, 0)])
+    @pytest.mark.parametrize("sort", [None, False])
+    def test_intersection_zero_length(self, period_1, period_2, sort):
+        index_1 = timedelta_range('1 day', periods=period_1, freq='h')
+        index_2 = timedelta_range('1 day', periods=period_2, freq='h')
+        inter = index_1.intersection(index_2, sort=sort)
+        tm.assert_index_equal(timedelta_range('1 day', periods=0, freq='h'),
+                              inter)
+
+    @pytest.mark.parametrize("rng, expected",
+                             # if target has the same name, it is preserved
+                             [(timedelta_range('1 day', periods=5,
+                                               freq='h', name='idx'),
+                               timedelta_range('1 day', periods=4,
+                                               freq='h', name='idx')),
+                              # if target name is different, it will be reset
+                              (timedelta_range('1 day', periods=5,
+                                               freq='h', name='other'),
+                               timedelta_range('1 day', periods=4,
+                                               freq='h', name=None)),
+                              # if no overlap exists return empty index
+                              (timedelta_range('1 day', periods=10,
+                                               freq='h', name='idx')[5:],
+                               TimedeltaIndex([], name='idx'))
+                              ])
+    @pytest.mark.parametrize("sort", [None, False])
+    def test_intersection(self, rng, expected, sort):
+        # GH 4690 (with tz)
+        base = timedelta_range('1 day', periods=4, freq='h', name='idx')
+        result = base.intersection(rng, sort=sort)
+        if sort is None:
+            expected = expected.sort_values()
+        tm.assert_index_equal(result, expected)
+        assert result.name == expected.name
+        assert result.freq == expected.freq
+
+    @pytest.mark.parametrize("rng, expected",
+                             # part intersection works
+                             [(TimedeltaIndex(['5 hour', '2 hour',
+                                               '4 hour', '9 hour'],
+                                              name='idx'),
+                               TimedeltaIndex(['2 hour', '4 hour'],
+                                              name='idx')),
+                              # reordered part intersection
+                              (TimedeltaIndex(['2 hour', '5 hour',
+                                               '5 hour', '1 hour'],
+                                              name='other'),
+                               TimedeltaIndex(['1 hour', '2 hour'],
+                                              name=None)),
+                              # reversed index
+                              (TimedeltaIndex(['1 hour', '2 hour',
+                                               '4 hour', '3 hour'],
+                                              name='idx')[::-1],
+                               TimedeltaIndex(['1 hour', '2 hour',
+                                               '4 hour', '3 hour'],
+                                              name='idx'))])
+    @pytest.mark.parametrize("sort", [None, False])
+    def test_intersection_non_monotonic(self, rng, expected, sort):
+        # non-monotonic
+        base = TimedeltaIndex(['1 hour', '2 hour', '4 hour', '3 hour'],
+                              name='idx')
+        result = base.intersection(rng, sort=sort)
+        if sort is None:
+            expected = expected.sort_values()
+        tm.assert_index_equal(result, expected)
+        assert result.name == expected.name
+
+        # if reversed order, frequency is still the same
+        if all(base == rng[::-1]) and sort is None:
+            assert isinstance(result.freq, Hour)
+        else:
+            assert result.freq is None