From c10de3a600bff44f62b63399994b5f5ff67fbf40 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Tue, 18 Apr 2023 12:58:47 -0700 Subject: [PATCH] Backport PR #52677 on branch 2.0.x (BUG: tz_localize with ArrowDtype) (#52762) Backport PR #52677: BUG: tz_localize with ArrowDtype --- doc/source/whatsnew/v2.0.1.rst | 1 + pandas/core/arrays/arrow/array.py | 21 ++++++++---- pandas/tests/extension/test_arrow.py | 48 +++++++++++++++++++++++++--- 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/doc/source/whatsnew/v2.0.1.rst b/doc/source/whatsnew/v2.0.1.rst index c64f7a46d3058..9c867544a324b 100644 --- a/doc/source/whatsnew/v2.0.1.rst +++ b/doc/source/whatsnew/v2.0.1.rst @@ -33,6 +33,7 @@ Bug fixes - Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`) - Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`) - Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`) +- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`) - Fixed bug in :func:`merge` when merging with ``ArrowDtype`` one one and a NumPy dtype on the other side (:issue:`52406`) - Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index c46420eb0b349..a9fb6cd801a82 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -2109,12 +2109,19 @@ def _dt_tz_localize( ): if ambiguous != "raise": raise NotImplementedError(f"{ambiguous=} is not supported") - if nonexistent != "raise": + nonexistent_pa = { + "raise": "raise", + "shift_backward": "earliest", + "shift_forward": "latest", + }.get( + nonexistent, None # type: ignore[arg-type] + ) + if nonexistent_pa is None: raise NotImplementedError(f"{nonexistent=} is not supported") if tz is None: - new_type = pa.timestamp(self.dtype.pyarrow_dtype.unit) - return type(self)(self._data.cast(new_type)) - pa_tz = str(tz) - return type(self)( - self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit, pa_tz)) - ) + result = self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit)) + else: + result = pc.assume_timezone( + self._data, str(tz), ambiguous=ambiguous, nonexistent=nonexistent_pa + ) + return type(self)(result) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index da1362a7bd26b..c0cde489c15a3 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2397,16 +2397,56 @@ def test_dt_tz_localize_none(): @pytest.mark.parametrize("unit", ["us", "ns"]) -def test_dt_tz_localize(unit): +def test_dt_tz_localize(unit, request): + if is_platform_windows() and is_ci_environment(): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + ) ser = pd.Series( [datetime(year=2023, month=1, day=2, hour=3), None], dtype=ArrowDtype(pa.timestamp(unit)), ) result = ser.dt.tz_localize("US/Pacific") - expected = pd.Series( - [datetime(year=2023, month=1, day=2, hour=3), None], - dtype=ArrowDtype(pa.timestamp(unit, "US/Pacific")), + exp_data = pa.array( + [datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit) + ) + exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific") + expected = pd.Series(ArrowExtensionArray(exp_data)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "nonexistent, exp_date", + [ + ["shift_forward", datetime(year=2023, month=3, day=12, hour=3)], + ["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")], + ], +) +def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request): + if is_platform_windows() and is_ci_environment(): + request.node.add_marker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + ) + ser = pd.Series( + [datetime(year=2023, month=3, day=12, hour=2, minute=30), None], + dtype=ArrowDtype(pa.timestamp("ns")), ) + result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent) + exp_data = pa.array([exp_date, None], type=pa.timestamp("ns")) + exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific") + expected = pd.Series(ArrowExtensionArray(exp_data)) tm.assert_series_equal(result, expected)