From a3ec4e54e7a1ddbb1c8647891dcb318bcd255a24 Mon Sep 17 00:00:00 2001 From: tp Date: Sun, 2 May 2021 15:22:46 +0100 Subject: [PATCH] Add more numeric tests for NumIndex --- pandas/core/indexes/base.py | 26 +++++++++-- pandas/core/indexes/numeric.py | 4 +- pandas/tests/indexes/common.py | 15 ++++-- pandas/tests/indexes/numeric/test_numeric.py | 49 +++++++++++++++----- 4 files changed, 74 insertions(+), 20 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 034aaf7487a5e5..72c42a542bde12 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -442,7 +442,7 @@ def __new__( return Index._simple_new(data, name=name) # index-like - elif type(data) is NumIndex and dtype is None: + elif isinstance(data, NumIndex) and data._is_num_index() and dtype is None: return NumIndex(data, name=name, copy=copy) elif isinstance(data, (np.ndarray, Index, ABCSeries)): @@ -2407,6 +2407,26 @@ def is_all_dates(self) -> bool: ) return self._is_all_dates + def _is_num_index(self) -> bool: + """ + Whether self is a NumIndex, but not *not* Int64Index, UInt64Index, FloatIndex. + + Typically used to check if an operation should return NumIndex or plain Index. + """ + from pandas.core.indexes.numeric import ( + Float64Index, + Int64Index, + NumIndex, + UInt64Index, + ) + + if not isinstance(self, NumIndex): + return False + elif isinstance(self, (Int64Index, UInt64Index, Float64Index)): + return False + else: + return True + # -------------------------------------------------------------------- # Pickle Methods @@ -5488,8 +5508,8 @@ def map(self, mapper, na_action=None): # empty attributes["dtype"] = self.dtype - if type(self) is NumIndex: - return type(self)(new_values, **attributes) + if self._is_num_index() and issubclass(new_values.dtype.type, np.number): + return NumIndex(new_values, **attributes) return Index(new_values, **attributes) diff --git a/pandas/core/indexes/numeric.py b/pandas/core/indexes/numeric.py index b97b07fec482c1..f1b3340ecbee5c 100644 --- a/pandas/core/indexes/numeric.py +++ b/pandas/core/indexes/numeric.py @@ -288,11 +288,11 @@ def astype(self, dtype, copy=True): # TODO(jreback); this can change once we have an EA Index type # GH 13149 arr = astype_nansafe(self._values, dtype=dtype) - if isinstance(self, Float64Index): + if not self._is_num_index(): return Int64Index(arr, name=self.name) else: return NumIndex(arr, name=self.name, dtype=dtype) - elif is_categorical_dtype(dtype): + if is_categorical_dtype(dtype): from pandas import CategoricalIndex return CategoricalIndex(self, name=self.name, dtype=dtype, copy=copy) diff --git a/pandas/tests/indexes/common.py b/pandas/tests/indexes/common.py index ad1019004a2aaa..87d35d8abff527 100644 --- a/pandas/tests/indexes/common.py +++ b/pandas/tests/indexes/common.py @@ -12,7 +12,7 @@ from pandas.core.dtypes.dtypes import CategoricalDtype import pandas as pd -from pandas import ( # noqa +from pandas import ( CategoricalIndex, DatetimeIndex, Float64Index, @@ -29,6 +29,7 @@ ) import pandas._testing as tm from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin +from pandas.core.indexes.numeric import NumIndex class Base: @@ -343,12 +344,13 @@ def test_numpy_argsort(self, index): def test_repeat(self, simple_index): rep = 2 idx = simple_index.copy() - expected = Index(idx.values.repeat(rep), name=idx.name) + new_index_cls = type(idx) if not isinstance(idx, RangeIndex) else Int64Index + expected = new_index_cls(idx.values.repeat(rep), name=idx.name) tm.assert_index_equal(idx.repeat(rep), expected) idx = simple_index rep = np.arange(len(idx)) - expected = Index(idx.values.repeat(rep), name=idx.name) + expected = new_index_cls(idx.values.repeat(rep), name=idx.name) tm.assert_index_equal(idx.repeat(rep), expected) def test_numpy_repeat(self, simple_index): @@ -649,7 +651,12 @@ def test_map_dictlike(self, mapper, simple_index): tm.assert_index_equal(result, expected) # empty mappable - expected = Index([np.nan] * len(idx)) + if idx._is_num_index(): + new_index_cls = NumIndex + else: + new_index_cls = Float64Index + + expected = new_index_cls([np.nan] * len(idx)) result = idx.map(mapper(expected, idx)) tm.assert_index_equal(result, expected) diff --git a/pandas/tests/indexes/numeric/test_numeric.py b/pandas/tests/indexes/numeric/test_numeric.py index eff6e2e21620a0..a8357fc63ffb86 100644 --- a/pandas/tests/indexes/numeric/test_numeric.py +++ b/pandas/tests/indexes/numeric/test_numeric.py @@ -13,6 +13,7 @@ UInt64Index, ) import pandas._testing as tm +from pandas.core.indexes.numeric import NumIndex from pandas.tests.indexes.common import NumericBase @@ -20,10 +21,17 @@ class TestFloat64Index(NumericBase): _index_cls = Float64Index _dtype = np.float64 - @pytest.fixture - def simple_index(self) -> Index: - values = np.arange(5, dtype=self._dtype) - return self._index_cls(values) + @pytest.fixture( + params=[ + (Float64Index, None), + (NumIndex, np.float64), + (NumIndex, np.float32), + ], + ) + def simple_index(self, request) -> Index: + index_cls, dtype = request.param + values = np.arange(5, dtype=dtype) + return index_cls(values) @pytest.fixture( params=[ @@ -392,9 +400,19 @@ class TestInt64Index(NumericInt): _index_cls = Int64Index _dtype = np.int64 - @pytest.fixture - def simple_index(self) -> Index: - return self._index_cls(range(0, 20, 2), dtype=self._dtype) + @pytest.fixture( + params=[ + (Int64Index, None), + (NumIndex, np.int64), + (NumIndex, np.int32), + (NumIndex, np.int16), + (NumIndex, np.int8), + ], + ) + def simple_index(self, request) -> Index: + index_cls, dtype = request.param + values = np.arange(5, dtype=dtype) + return index_cls(values) @pytest.fixture( params=[range(0, 20, 2), range(19, -1, -1)], ids=["index_inc", "index_dec"] @@ -490,10 +508,19 @@ class TestUInt64Index(NumericInt): _index_cls = UInt64Index _dtype = np.uint64 - @pytest.fixture - def simple_index(self) -> Index: - # compat with shared Int64/Float64 tests - return self._index_cls(np.arange(5, dtype=self._dtype)) + @pytest.fixture( + params=[ + (UInt64Index, None), + (NumIndex, np.uint64), + (NumIndex, np.uint32), + (NumIndex, np.uint16), + (NumIndex, np.uint8), + ], + ) + def simple_index(self, request) -> Index: + index_cls, dtype = request.param + values = np.arange(5, dtype=dtype) + return index_cls(values) @pytest.fixture( params=[