Skip to content

PERF: use fastpaths for is_period_dtype checks #33937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pandas/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1519,11 +1519,13 @@ def box_expected(expected, box_cls, transpose=True):

def to_array(obj):
# temporary implementation until we get pd.array in place
if is_period_dtype(obj):
dtype = getattr(obj, "dtype", None)

if is_period_dtype(dtype):
return period_array(obj)
elif is_datetime64_dtype(obj) or is_datetime64tz_dtype(obj):
elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype):
return DatetimeArray._from_sequence(obj)
elif is_timedelta64_dtype(obj):
elif is_timedelta64_dtype(dtype):
return TimedeltaArray._from_sequence(obj)
else:
return np.array(obj)
Expand Down
17 changes: 9 additions & 8 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from pandas._libs import Timestamp, algos, hashtable as htable, lib
from pandas._libs.tslib import iNaT
from pandas._typing import AnyArrayLike
from pandas._typing import AnyArrayLike, DtypeObj
from pandas.util._decorators import doc

from pandas.core.dtypes.cast import (
Expand Down Expand Up @@ -126,20 +126,21 @@ def _ensure_data(values, dtype=None):
return ensure_object(values), "object"

# datetimelike
if needs_i8_conversion(values) or needs_i8_conversion(dtype):
if is_period_dtype(values) or is_period_dtype(dtype):
vals_dtype = getattr(values, "dtype", None)
if needs_i8_conversion(vals_dtype) or needs_i8_conversion(dtype):
if is_period_dtype(vals_dtype) or is_period_dtype(dtype):
from pandas import PeriodIndex

values = PeriodIndex(values)
dtype = values.dtype
elif is_timedelta64_dtype(values) or is_timedelta64_dtype(dtype):
elif is_timedelta64_dtype(vals_dtype) or is_timedelta64_dtype(dtype):
from pandas import TimedeltaIndex

values = TimedeltaIndex(values)
dtype = values.dtype
else:
# Datetime
if values.ndim > 1 and is_datetime64_ns_dtype(values):
if values.ndim > 1 and is_datetime64_ns_dtype(vals_dtype):
# Avoid calling the DatetimeIndex constructor as it is 1D only
# Note: this is reached by DataFrame.rank calls GH#27027
# TODO(EA2D): special case not needed with 2D EAs
Expand All @@ -154,7 +155,7 @@ def _ensure_data(values, dtype=None):

return values.asi8, dtype

elif is_categorical_dtype(values) and (
elif is_categorical_dtype(vals_dtype) and (
is_categorical_dtype(dtype) or dtype is None
):
values = values.codes
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def nsmallest(self):
return self.compute("nsmallest")

@staticmethod
def is_valid_dtype_n_method(dtype) -> bool:
def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
"""
Helper function to determine if dtype is valid for
nsmallest/nlargest methods
Expand Down Expand Up @@ -1863,7 +1864,7 @@ def diff(arr, n: int, axis: int = 0, stacklevel=3):

is_timedelta = False
is_bool = False
if needs_i8_conversion(arr):
if needs_i8_conversion(arr.dtype):
dtype = np.float64
arr = arr.view("i8")
na = iNaT
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ def _internal_get_values(self):
Index if datetime / periods.
"""
# if we are a datetime and period index, return Index to keep metadata
if needs_i8_conversion(self.categories):
if needs_i8_conversion(self.categories.dtype):
return self.categories.take(self._codes, fill_value=np.nan)
elif is_integer_dtype(self.categories) and -1 in self._codes:
return self.categories.astype("object").take(self._codes, fill_value=np.nan)
Expand Down
35 changes: 17 additions & 18 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def _validate_searchsorted_value(self, value):
elif is_list_like(value) and not isinstance(value, type(self)):
value = array(value)

if not type(self)._is_recognized_dtype(value):
if not type(self)._is_recognized_dtype(value.dtype):
raise TypeError(
"searchsorted requires compatible dtype or scalar, "
f"not {type(value).__name__}"
Expand All @@ -806,7 +806,7 @@ def _validate_setitem_value(self, value):
except ValueError:
pass

if not type(self)._is_recognized_dtype(value):
if not type(self)._is_recognized_dtype(value.dtype):
raise TypeError(
"setitem requires compatible dtype or scalar, "
f"not {type(value).__name__}"
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def fillna(self, value=None, method=None, limit=None):
func = missing.backfill_1d

values = self._data
if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
# For PeriodArray self._data is i8, which gets copied
# by `func`. Otherwise we need to make a copy manually
# to avoid modifying `self` in-place.
Expand Down Expand Up @@ -1109,10 +1109,7 @@ def _validate_frequency(cls, index, freq, **kwargs):
freq : DateOffset
The frequency to validate
"""
if is_period_dtype(cls):
# Frequency validation is not meaningful for Period Array/Index
return None

# TODO: this is not applicable to PeriodArray, move to correct Mixin
inferred = index.inferred_freq
if index.size == 0 or inferred == freq.freqstr:
return None
Expand Down Expand Up @@ -1253,7 +1250,7 @@ def _add_nat(self):
"""
Add pd.NaT to self
"""
if is_period_dtype(self):
if is_period_dtype(self.dtype):
raise TypeError(
f"Cannot add {type(self).__name__} and {type(NaT).__name__}"
)
Expand Down Expand Up @@ -1293,7 +1290,7 @@ def _sub_period_array(self, other):
result : np.ndarray[object]
Array of DateOffset objects; nulls represented by NaT.
"""
if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
raise TypeError(
f"cannot subtract {other.dtype}-dtype from {type(self).__name__}"
)
Expand Down Expand Up @@ -1398,7 +1395,7 @@ def __add__(self, other):
elif lib.is_integer(other):
# This check must come after the check for np.timedelta64
# as is_integer returns True for these
if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
result = self._time_shift(other)

Expand All @@ -1413,7 +1410,7 @@ def __add__(self, other):
# DatetimeIndex, ndarray[datetime64]
return self._add_datetime_arraylike(other)
elif is_integer_dtype(other):
if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
result = self._addsub_int_array(other, operator.add)
else:
Expand All @@ -1437,6 +1434,8 @@ def __radd__(self, other):
@unpack_zerodim_and_defer("__sub__")
def __sub__(self, other):

other_dtype = getattr(other, "dtype", None)

# scalar others
if other is NaT:
result = self._sub_nat()
Expand All @@ -1450,7 +1449,7 @@ def __sub__(self, other):
elif lib.is_integer(other):
# This check must come after the check for np.timedelta64
# as is_integer returns True for these
if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
result = self._time_shift(-other)

Expand All @@ -1467,11 +1466,11 @@ def __sub__(self, other):
elif is_datetime64_dtype(other) or is_datetime64tz_dtype(other):
# DatetimeIndex, ndarray[datetime64]
result = self._sub_datetime_arraylike(other)
elif is_period_dtype(other):
elif is_period_dtype(other_dtype):
# PeriodIndex
result = self._sub_period_array(other)
elif is_integer_dtype(other):
if not is_period_dtype(self):
elif is_integer_dtype(other_dtype):
if not is_period_dtype(self.dtype):
raise integer_op_not_supported(self)
result = self._addsub_int_array(other, operator.sub)
else:
Expand Down Expand Up @@ -1520,7 +1519,7 @@ def __iadd__(self, other):
result = self + other
self[:] = result[:]

if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
# restore freq, which is invalidated by setitem
self._freq = result._freq
return self
Expand All @@ -1529,7 +1528,7 @@ def __isub__(self, other):
result = self - other
self[:] = result[:]

if not is_period_dtype(self):
if not is_period_dtype(self.dtype):
# restore freq, which is invalidated by setitem
self._freq = result._freq
return self
Expand Down Expand Up @@ -1621,7 +1620,7 @@ def mean(self, skipna=True):
-----
mean is only defined for Datetime and Timedelta dtypes, not for Period.
"""
if is_period_dtype(self):
if is_period_dtype(self.dtype):
# See discussion in GH#24757
raise TypeError(
f"mean is not implemented for {type(self).__name__} since the "
Expand Down
6 changes: 4 additions & 2 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,11 @@ def period_array(
['2000Q1', '2000Q2', '2000Q3', '2000Q4']
Length: 4, dtype: period[Q-DEC]
"""
if is_datetime64_dtype(data):
data_dtype = getattr(data, "dtype", None)

if is_datetime64_dtype(data_dtype):
return PeriodArray._from_datetime64(data, freq)
if is_period_dtype(data):
if is_period_dtype(data_dtype):
return PeriodArray(data, freq)

# other iterable of some kind
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/dtypes/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
return True

# NaNs can occur in float and complex arrays.
if is_float_dtype(left) or is_complex_dtype(left):
if is_float_dtype(left.dtype) or is_complex_dtype(left.dtype):

# empty
if not (np.prod(left.shape) and np.prod(right.shape)):
Expand All @@ -435,7 +435,7 @@ def array_equivalent(left, right, strict_nan: bool = False) -> bool:
# GH#29553 avoid numpy deprecation warning
return False

elif needs_i8_conversion(left) or needs_i8_conversion(right):
elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
# datetime64, timedelta64, Period
if not is_dtype_equal(left.dtype, right.dtype):
return False
Expand All @@ -460,7 +460,7 @@ def _infer_fill_value(val):
if not is_list_like(val):
val = [val]
val = np.array(val, copy=False)
if needs_i8_conversion(val):
if needs_i8_conversion(val.dtype):
return np.array("NaT", dtype=val.dtype)
elif is_object_dtype(val.dtype):
dtype = lib.infer_dtype(ensure_object(val), skipna=False)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5916,7 +5916,7 @@ def extract_values(arr):
if isinstance(arr, (ABCIndexClass, ABCSeries)):
arr = arr._values

if needs_i8_conversion(arr):
if needs_i8_conversion(arr.dtype):
if is_extension_array_dtype(arr.dtype):
arr = arr.asi8
else:
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/indexes/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _get_values(self):
elif is_timedelta64_dtype(data.dtype):
return TimedeltaIndex(data, copy=False, name=self.name)

elif is_period_dtype(data):
elif is_period_dtype(data.dtype):
return PeriodArray(data, copy=False)

raise TypeError(
Expand Down Expand Up @@ -449,7 +449,7 @@ def __new__(cls, data: "Series"):
return DatetimeProperties(data, orig)
elif is_timedelta64_dtype(data.dtype):
return TimedeltaProperties(data, orig)
elif is_period_dtype(data):
elif is_period_dtype(data.dtype):
return PeriodProperties(data, orig)

raise AttributeError("Can only use .dt accessor with datetimelike values")
4 changes: 2 additions & 2 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __array_wrap__(self, result, context=None):
return result

attrs = self._get_attributes_dict()
if not is_period_dtype(self) and attrs["freq"]:
if not is_period_dtype(self.dtype) and attrs["freq"]:
# no need to infer if freq is None
attrs["freq"] = "infer"
return Index(result, **attrs)
Expand Down Expand Up @@ -542,7 +542,7 @@ def delete(self, loc):
new_i8s = np.delete(self.asi8, loc)

freq = None
if is_period_dtype(self):
if is_period_dtype(self.dtype):
freq = self.freq
elif is_integer(loc):
if loc in (0, -len(self), -1, len(self) - 1):
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2673,13 +2673,13 @@ def get_block_type(values, dtype=None):
elif is_categorical_dtype(values.dtype):
cls = CategoricalBlock
elif issubclass(vtype, np.datetime64):
assert not is_datetime64tz_dtype(values)
assert not is_datetime64tz_dtype(values.dtype)
cls = DatetimeBlock
elif is_datetime64tz_dtype(values):
elif is_datetime64tz_dtype(values.dtype):
cls = DatetimeTZBlock
elif is_interval_dtype(dtype) or is_period_dtype(dtype):
cls = ObjectValuesExtensionBlock
elif is_extension_array_dtype(values):
elif is_extension_array_dtype(values.dtype):
cls = ExtensionBlock
elif issubclass(vtype, np.floating):
cls = FloatBlock
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def interpolate_1d(
if method in ("values", "index"):
inds = np.asarray(xvalues)
# hack for DatetimeIndex, #1646
if needs_i8_conversion(inds.dtype.type):
if needs_i8_conversion(inds.dtype):
inds = inds.view(np.int64)
if inds.dtype == np.object_:
inds = lib.maybe_convert_objects(inds)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def _get_values(

dtype = values.dtype

if needs_i8_conversion(values):
if needs_i8_conversion(values.dtype):
# changing timedelta64/datetime64 to int64 needs to happen after
# finding `mask` above
values = np.asarray(values.view("i8"))
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/reshape/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ def get_new_values(self, values, fill_value=None):
# we need to convert to a basic dtype
# and possibly coerce an input to our output dtype
# e.g. ints -> floats
if needs_i8_conversion(values):
if needs_i8_conversion(values.dtype):
sorted_values = sorted_values.view("i8")
new_values = new_values.view("i8")
elif is_bool_dtype(values):
elif is_bool_dtype(values.dtype):
sorted_values = sorted_values.astype("object")
new_values = new_values.astype("object")
else:
Expand All @@ -253,7 +253,7 @@ def get_new_values(self, values, fill_value=None):
)

# reconstruct dtype if needed
if needs_i8_conversion(values):
if needs_i8_conversion(values.dtype):
new_values = new_values.view(values.dtype)

return new_values, new_mask
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/json/_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def __init__(
if len(timedeltas):
obj[timedeltas] = obj[timedeltas].applymap(lambda x: x.isoformat())
# Convert PeriodIndex to datetimes before serializing
if is_period_dtype(obj.index):
if is_period_dtype(obj.index.dtype):
obj.index = obj.index.to_timestamp()

# exclude index from obj if index=False
Expand Down
Loading