Skip to content

Commit

Permalink
REF: enforce annotation in maybe_downcast_to_dtype (#40982)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Apr 16, 2021
1 parent ece1217 commit 4340689
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 48 deletions.
28 changes: 4 additions & 24 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

from contextlib import suppress
from datetime import (
date,
datetime,
Expand All @@ -29,7 +28,6 @@
NaT,
OutOfBoundsDatetime,
OutOfBoundsTimedelta,
Period,
Timedelta,
Timestamp,
conversion,
Expand Down Expand Up @@ -87,7 +85,6 @@
PeriodDtype,
)
from pandas.core.dtypes.generic import (
ABCDataFrame,
ABCExtensionArray,
ABCSeries,
)
Expand Down Expand Up @@ -249,9 +246,6 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
try to cast to the specified dtype (e.g. convert back to bool/int
or could be an astype of float64->float32
"""
if isinstance(result, ABCDataFrame):
# see test_pivot_table_doctest_case
return result
do_round = False

if isinstance(dtype, str):
Expand All @@ -278,15 +272,9 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi

dtype = np.dtype(dtype)

elif dtype.type is Period:
from pandas.core.arrays import PeriodArray

with suppress(TypeError):
# e.g. TypeError: int() argument must be a string, a
# bytes-like object or a number, not 'Period

# error: "dtype[Any]" has no attribute "freq"
return PeriodArray(result, freq=dtype.freq) # type: ignore[attr-defined]
if not isinstance(dtype, np.dtype):
# enforce our signature annotation
raise TypeError(dtype) # pragma: no cover

converted = maybe_downcast_numeric(result, dtype, do_round)
if converted is not result:
Expand All @@ -295,15 +283,7 @@ def maybe_downcast_to_dtype(result: ArrayLike, dtype: str | np.dtype) -> ArrayLi
# a datetimelike
# GH12821, iNaT is cast to float
if dtype.kind in ["M", "m"] and result.dtype.kind in ["i", "f"]:
if isinstance(dtype, DatetimeTZDtype):
# convert to datetime and change timezone
i8values = result.astype("i8", copy=False)
cls = dtype.construct_array_type()
# equiv: DatetimeArray(i8values).tz_localize("UTC").tz_convert(dtype.tz)
dt64values = i8values.view("M8[ns]")
result = cls._simple_new(dt64values, dtype=dtype)
else:
result = result.astype(dtype)
result = result.astype(dtype)

return result

Expand Down
11 changes: 6 additions & 5 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7213,13 +7213,14 @@ def combine(
else:
# if we have different dtypes, possibly promote
new_dtype = find_common_type([this_dtype, other_dtype])
if not is_dtype_equal(this_dtype, new_dtype):
series = series.astype(new_dtype)
if not is_dtype_equal(other_dtype, new_dtype):
otherSeries = otherSeries.astype(new_dtype)
series = series.astype(new_dtype, copy=False)
otherSeries = otherSeries.astype(new_dtype, copy=False)

arr = func(series, otherSeries)
arr = maybe_downcast_to_dtype(arr, new_dtype)
if isinstance(new_dtype, np.dtype):
# if new_dtype is an EA Dtype, then `func` is expected to return
# the correct dtype without any additional casting
arr = maybe_downcast_to_dtype(arr, new_dtype)

result[col] = arr

Expand Down
10 changes: 9 additions & 1 deletion pandas/core/reshape/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,15 @@ def __internal_pivot_table(
and v in agged
and not is_integer_dtype(agged[v])
):
agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)
if isinstance(agged[v], ABCDataFrame):
# exclude DataFrame case bc maybe_downcast_to_dtype expects
# ArrayLike
# TODO: why does test_pivot_table_doctest_case fail if
# we don't do this apparently-unnecessary setitem?
agged[v] = agged[v]
pass
else:
agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

table = agged

Expand Down
20 changes: 2 additions & 18 deletions pandas/tests/dtypes/cast/test_downcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@

from pandas.core.dtypes.cast import maybe_downcast_to_dtype

from pandas import (
DatetimeIndex,
Series,
Timestamp,
)
from pandas import Series
import pandas._testing as tm


Expand Down Expand Up @@ -77,7 +73,7 @@ def test_downcast_conversion_nan(float_dtype):
def test_downcast_conversion_empty(any_real_dtype):
dtype = any_real_dtype
arr = np.array([], dtype=dtype)
result = maybe_downcast_to_dtype(arr, "int64")
result = maybe_downcast_to_dtype(arr, np.dtype("int64"))
tm.assert_numpy_array_equal(result, np.array([], dtype=np.int64))


Expand All @@ -89,15 +85,3 @@ def test_datetime_likes_nan(klass):
exp = np.array([1, 2, klass("NaT")], dtype)
res = maybe_downcast_to_dtype(arr, dtype)
tm.assert_numpy_array_equal(res, exp)


@pytest.mark.parametrize("as_asi", [True, False])
def test_datetime_with_timezone(as_asi):
# see gh-15426
ts = Timestamp("2016-01-01 12:00:00", tz="US/Pacific")
exp = DatetimeIndex([ts, ts])._data

obj = exp.asi8 if as_asi else exp
res = maybe_downcast_to_dtype(obj, exp.dtype)

tm.assert_datetime_array_equal(res, exp)

0 comments on commit 4340689

Please sign in to comment.