Skip to content

Commit

Permalink
Add more numeric tests for NumIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 committed May 2, 2021
1 parent 0539d8c commit a3ec4e5
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 20 deletions.
26 changes: 23 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def __new__(
return Index._simple_new(data, name=name)

# index-like
elif type(data) is NumIndex and dtype is None:
elif isinstance(data, NumIndex) and data._is_num_index() and dtype is None:
return NumIndex(data, name=name, copy=copy)
elif isinstance(data, (np.ndarray, Index, ABCSeries)):

Expand Down Expand Up @@ -2407,6 +2407,26 @@ def is_all_dates(self) -> bool:
)
return self._is_all_dates

def _is_num_index(self) -> bool:
"""
Whether self is a NumIndex, but not *not* Int64Index, UInt64Index, FloatIndex.
Typically used to check if an operation should return NumIndex or plain Index.
"""
from pandas.core.indexes.numeric import (
Float64Index,
Int64Index,
NumIndex,
UInt64Index,
)

if not isinstance(self, NumIndex):
return False
elif isinstance(self, (Int64Index, UInt64Index, Float64Index)):
return False
else:
return True

# --------------------------------------------------------------------
# Pickle Methods

Expand Down Expand Up @@ -5488,8 +5508,8 @@ def map(self, mapper, na_action=None):
# empty
attributes["dtype"] = self.dtype

if type(self) is NumIndex:
return type(self)(new_values, **attributes)
if self._is_num_index() and issubclass(new_values.dtype.type, np.number):
return NumIndex(new_values, **attributes)

return Index(new_values, **attributes)

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,11 @@ def astype(self, dtype, copy=True):
# TODO(jreback); this can change once we have an EA Index type
# GH 13149
arr = astype_nansafe(self._values, dtype=dtype)
if isinstance(self, Float64Index):
if not self._is_num_index():
return Int64Index(arr, name=self.name)
else:
return NumIndex(arr, name=self.name, dtype=dtype)
elif is_categorical_dtype(dtype):
if is_categorical_dtype(dtype):
from pandas import CategoricalIndex

return CategoricalIndex(self, name=self.name, dtype=dtype, copy=copy)
Expand Down
15 changes: 11 additions & 4 deletions pandas/tests/indexes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pandas.core.dtypes.dtypes import CategoricalDtype

import pandas as pd
from pandas import ( # noqa
from pandas import (
CategoricalIndex,
DatetimeIndex,
Float64Index,
Expand All @@ -29,6 +29,7 @@
)
import pandas._testing as tm
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
from pandas.core.indexes.numeric import NumIndex


class Base:
Expand Down Expand Up @@ -343,12 +344,13 @@ def test_numpy_argsort(self, index):
def test_repeat(self, simple_index):
rep = 2
idx = simple_index.copy()
expected = Index(idx.values.repeat(rep), name=idx.name)
new_index_cls = type(idx) if not isinstance(idx, RangeIndex) else Int64Index
expected = new_index_cls(idx.values.repeat(rep), name=idx.name)
tm.assert_index_equal(idx.repeat(rep), expected)

idx = simple_index
rep = np.arange(len(idx))
expected = Index(idx.values.repeat(rep), name=idx.name)
expected = new_index_cls(idx.values.repeat(rep), name=idx.name)
tm.assert_index_equal(idx.repeat(rep), expected)

def test_numpy_repeat(self, simple_index):
Expand Down Expand Up @@ -649,7 +651,12 @@ def test_map_dictlike(self, mapper, simple_index):
tm.assert_index_equal(result, expected)

# empty mappable
expected = Index([np.nan] * len(idx))
if idx._is_num_index():
new_index_cls = NumIndex
else:
new_index_cls = Float64Index

expected = new_index_cls([np.nan] * len(idx))
result = idx.map(mapper(expected, idx))
tm.assert_index_equal(result, expected)

Expand Down
49 changes: 38 additions & 11 deletions pandas/tests/indexes/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@
UInt64Index,
)
import pandas._testing as tm
from pandas.core.indexes.numeric import NumIndex
from pandas.tests.indexes.common import NumericBase


class TestFloat64Index(NumericBase):
_index_cls = Float64Index
_dtype = np.float64

@pytest.fixture
def simple_index(self) -> Index:
values = np.arange(5, dtype=self._dtype)
return self._index_cls(values)
@pytest.fixture(
params=[
(Float64Index, None),
(NumIndex, np.float64),
(NumIndex, np.float32),
],
)
def simple_index(self, request) -> Index:
index_cls, dtype = request.param
values = np.arange(5, dtype=dtype)
return index_cls(values)

@pytest.fixture(
params=[
Expand Down Expand Up @@ -392,9 +400,19 @@ class TestInt64Index(NumericInt):
_index_cls = Int64Index
_dtype = np.int64

@pytest.fixture
def simple_index(self) -> Index:
return self._index_cls(range(0, 20, 2), dtype=self._dtype)
@pytest.fixture(
params=[
(Int64Index, None),
(NumIndex, np.int64),
(NumIndex, np.int32),
(NumIndex, np.int16),
(NumIndex, np.int8),
],
)
def simple_index(self, request) -> Index:
index_cls, dtype = request.param
values = np.arange(5, dtype=dtype)
return index_cls(values)

@pytest.fixture(
params=[range(0, 20, 2), range(19, -1, -1)], ids=["index_inc", "index_dec"]
Expand Down Expand Up @@ -490,10 +508,19 @@ class TestUInt64Index(NumericInt):
_index_cls = UInt64Index
_dtype = np.uint64

@pytest.fixture
def simple_index(self) -> Index:
# compat with shared Int64/Float64 tests
return self._index_cls(np.arange(5, dtype=self._dtype))
@pytest.fixture(
params=[
(UInt64Index, None),
(NumIndex, np.uint64),
(NumIndex, np.uint32),
(NumIndex, np.uint16),
(NumIndex, np.uint8),
],
)
def simple_index(self, request) -> Index:
index_cls, dtype = request.param
values = np.arange(5, dtype=dtype)
return index_cls(values)

@pytest.fixture(
params=[
Expand Down

0 comments on commit a3ec4e5

Please sign in to comment.