diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx
index 28d269a9a809e..74815f64360b9 100644
--- a/pandas/_libs/index.pyx
+++ b/pandas/_libs/index.pyx
@@ -17,8 +17,8 @@ cnp.import_array()
 
 cimport pandas._libs.util as util
 
-from pandas._libs.tslibs.conversion cimport maybe_datetimelike_to_i8
 from pandas._libs.tslibs.nattype cimport c_NaT as NaT
+from pandas._libs.tslibs.c_timestamp cimport _Timestamp
 
 from pandas._libs.hashtable cimport HashTable
 
@@ -409,20 +409,27 @@ cdef class DatetimeEngine(Int64Engine):
     cdef _get_box_dtype(self):
         return 'M8[ns]'
 
+    cdef int64_t _unbox_scalar(self, scalar) except? -1:
+        # NB: caller is responsible for ensuring tzawareness compat
+        #  before we get here
+        if not (isinstance(scalar, _Timestamp) or scalar is NaT):
+            raise TypeError(scalar)
+        return scalar.value
+
     def __contains__(self, object val):
         cdef:
-            int64_t loc
+            int64_t loc, conv
 
+        conv = self._unbox_scalar(val)
         if self.over_size_threshold and self.is_monotonic_increasing:
             if not self.is_unique:
-                return self._get_loc_duplicates(val)
+                return self._get_loc_duplicates(conv)
             values = self._get_index_values()
-            conv = maybe_datetimelike_to_i8(val)
             loc = values.searchsorted(conv, side='left')
             return values[loc] == conv
 
         self._ensure_mapping_populated()
-        return maybe_datetimelike_to_i8(val) in self.mapping
+        return conv in self.mapping
 
     cdef _get_index_values(self):
         return self.vgetter().view('i8')
@@ -431,23 +438,26 @@ cdef class DatetimeEngine(Int64Engine):
         return algos.is_monotonic(values, timelike=True)
 
     cpdef get_loc(self, object val):
+        # NB: the caller is responsible for ensuring that we are called
+        #  with either a Timestamp or NaT (Timedelta or NaT for TimedeltaEngine)
+
         cdef:
             int64_t loc
         if is_definitely_invalid_key(val):
             raise TypeError
 
+        try:
+            conv = self._unbox_scalar(val)
+        except TypeError:
+            raise KeyError(val)
+
         # Welcome to the spaghetti factory
         if self.over_size_threshold and self.is_monotonic_increasing:
             if not self.is_unique:
-                val = maybe_datetimelike_to_i8(val)
-                return self._get_loc_duplicates(val)
+                return self._get_loc_duplicates(conv)
             values = self._get_index_values()
 
-            try:
-                conv = maybe_datetimelike_to_i8(val)
-                loc = values.searchsorted(conv, side='left')
-            except TypeError:
-                raise KeyError(val)
+            loc = values.searchsorted(conv, side='left')
 
             if loc == len(values) or values[loc] != conv:
                 raise KeyError(val)
@@ -455,21 +465,12 @@ cdef class DatetimeEngine(Int64Engine):
 
         self._ensure_mapping_populated()
         if not self.unique:
-            val = maybe_datetimelike_to_i8(val)
-            return self._get_loc_duplicates(val)
+            return self._get_loc_duplicates(conv)
 
         try:
-            return self.mapping.get_item(val.value)
+            return self.mapping.get_item(conv)
         except KeyError:
             raise KeyError(val)
-        except AttributeError:
-            pass
-
-        try:
-            val = maybe_datetimelike_to_i8(val)
-            return self.mapping.get_item(val)
-        except (TypeError, ValueError):
-            raise KeyError(val)
 
     def get_indexer(self, values):
         self._ensure_mapping_populated()
@@ -496,6 +497,11 @@ cdef class TimedeltaEngine(DatetimeEngine):
     cdef _get_box_dtype(self):
         return 'm8[ns]'
 
+    cdef int64_t _unbox_scalar(self, scalar) except? -1:
+        if not (isinstance(scalar, Timedelta) or scalar is NaT):
+            raise TypeError(scalar)
+        return scalar.value
+
 
 cdef class PeriodEngine(Int64Engine):
 
diff --git a/pandas/_libs/tslibs/conversion.pxd b/pandas/_libs/tslibs/conversion.pxd
index 36e6b14be182a..d4ae3fa8c5b99 100644
--- a/pandas/_libs/tslibs/conversion.pxd
+++ b/pandas/_libs/tslibs/conversion.pxd
@@ -25,6 +25,4 @@ cdef int64_t get_datetime64_nanos(object val) except? -1
 
 cpdef int64_t pydt_to_i8(object pydt) except? -1
 
-cdef maybe_datetimelike_to_i8(object val)
-
 cpdef datetime localize_pydatetime(datetime dt, object tz)
diff --git a/pandas/_libs/tslibs/conversion.pyx b/pandas/_libs/tslibs/conversion.pyx
index 2988d7bae9a5e..f22b7bb6a3687 100644
--- a/pandas/_libs/tslibs/conversion.pyx
+++ b/pandas/_libs/tslibs/conversion.pyx
@@ -202,31 +202,6 @@ def datetime_to_datetime64(object[:] values):
     return result, inferred_tz
 
 
-cdef inline maybe_datetimelike_to_i8(object val):
-    """
-    Try to convert to a nanosecond timestamp. Fall back to returning the
-    input value.
-
-    Parameters
-    ----------
-    val : object
-
-    Returns
-    -------
-    val : int64 timestamp or original input
-    """
-    cdef:
-        npy_datetimestruct dts
-    try:
-        return val.value
-    except AttributeError:
-        if is_datetime64_object(val):
-            return get_datetime64_value(val)
-        elif PyDateTime_Check(val):
-            return convert_datetime_to_tsobject(val, None).value
-        return val
-
-
 # ----------------------------------------------------------------------
 # _TSObject Conversion
 
diff --git a/pandas/tests/indexes/test_engines.py b/pandas/tests/indexes/test_engines.py
new file mode 100644
index 0000000000000..ee224c9c6ec89
--- /dev/null
+++ b/pandas/tests/indexes/test_engines.py
@@ -0,0 +1,57 @@
+import re
+
+import pytest
+
+import pandas as pd
+
+
+class TestDatetimeEngine:
+    @pytest.mark.parametrize(
+        "scalar",
+        [
+            pd.Timedelta(pd.Timestamp("2016-01-01").asm8.view("m8[ns]")),
+            pd.Timestamp("2016-01-01").value,
+            pd.Timestamp("2016-01-01").to_pydatetime(),
+            pd.Timestamp("2016-01-01").to_datetime64(),
+        ],
+    )
+    def test_not_contains_requires_timestamp(self, scalar):
+        dti1 = pd.date_range("2016-01-01", periods=3)
+        dti2 = dti1.insert(1, pd.NaT)  # non-monotonic
+        dti3 = dti1.insert(3, dti1[0])  # non-unique
+        dti4 = pd.date_range("2016-01-01", freq="ns", periods=2_000_000)
+        dti5 = dti4.insert(0, dti4[0])  # over size threshold, not unique
+
+        msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))])
+        for dti in [dti1, dti2, dti3, dti4, dti5]:
+            with pytest.raises(TypeError, match=msg):
+                scalar in dti._engine
+
+            with pytest.raises(KeyError, match=msg):
+                dti._engine.get_loc(scalar)
+
+
+class TestTimedeltaEngine:
+    @pytest.mark.parametrize(
+        "scalar",
+        [
+            pd.Timestamp(pd.Timedelta(days=42).asm8.view("datetime64[ns]")),
+            pd.Timedelta(days=42).value,
+            pd.Timedelta(days=42).to_pytimedelta(),
+            pd.Timedelta(days=42).to_timedelta64(),
+        ],
+    )
+    def test_not_contains_requires_timestamp(self, scalar):
+        tdi1 = pd.timedelta_range("42 days", freq="9h", periods=1234)
+        tdi2 = tdi1.insert(1, pd.NaT)  # non-monotonic
+        tdi3 = tdi1.insert(3, tdi1[0])  # non-unique
+        tdi4 = pd.timedelta_range("42 days", freq="ns", periods=2_000_000)
+        tdi5 = tdi4.insert(0, tdi4[0])  # over size threshold, not unique
+
+        msg = "|".join([re.escape(str(scalar)), re.escape(repr(scalar))])
+        for tdi in [tdi1, tdi2, tdi3, tdi4, tdi5]:
+            with pytest.raises(TypeError, match=msg):
+                scalar in tdi._engine
+
+            with pytest.raises(KeyError, match=msg):
+                tdi._engine.get_loc(scalar)
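Illustrative note (not part of the patch): with maybe_datetimelike_to_i8 removed, the engines no longer coerce raw datetime64/integer/pydatetime keys themselves; the index-level caller is expected to box keys as Timestamp/Timedelta (or pass NaT) before the engine is reached. A minimal sketch of the resulting behaviour, exercising the same internal _engine attribute the new tests use, assuming a build that includes this change:

import numpy as np
import pandas as pd

dti = pd.date_range("2016-01-01", periods=3)

# A raw np.datetime64 key is no longer unboxed by the engine itself:
# _unbox_scalar raises TypeError, which get_loc surfaces as KeyError.
raw = np.datetime64("2016-01-02")
try:
    dti._engine.get_loc(raw)
except KeyError:
    pass

# Boxing the key as a Timestamp first (the caller's responsibility) works.
assert dti._engine.get_loc(pd.Timestamp(raw)) == 1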