Skip to content

Commit

Permalink
Backport PR #52677 on branch 2.0.x (BUG: tz_localize with ArrowDtype) (
Browse files Browse the repository at this point in the history
…#52762)

Backport PR #52677: BUG: tz_localize with ArrowDtype
  • Loading branch information
mroeschke authored Apr 18, 2023
1 parent ef66c8e commit c10de3a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Bug fixes
- Bug in :meth:`ArrowDtype.__from_arrow__` not respecting if dtype is explicitly given (:issue:`52533`)
- Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`)
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` on one side and a NumPy dtype on the other side (:issue:`52406`)
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)

Expand Down
21 changes: 14 additions & 7 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2109,12 +2109,19 @@ def _dt_tz_localize(
):
if ambiguous != "raise":
raise NotImplementedError(f"{ambiguous=} is not supported")
if nonexistent != "raise":
nonexistent_pa = {
"raise": "raise",
"shift_backward": "earliest",
"shift_forward": "latest",
}.get(
nonexistent, None # type: ignore[arg-type]
)
if nonexistent_pa is None:
raise NotImplementedError(f"{nonexistent=} is not supported")
if tz is None:
new_type = pa.timestamp(self.dtype.pyarrow_dtype.unit)
return type(self)(self._data.cast(new_type))
pa_tz = str(tz)
return type(self)(
self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit, pa_tz))
)
result = self._data.cast(pa.timestamp(self.dtype.pyarrow_dtype.unit))
else:
result = pc.assume_timezone(
self._data, str(tz), ambiguous=ambiguous, nonexistent=nonexistent_pa
)
return type(self)(result)
48 changes: 44 additions & 4 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2397,16 +2397,56 @@ def test_dt_tz_localize_none():


@pytest.mark.parametrize("unit", ["us", "ns"])
def test_dt_tz_localize(unit):
def test_dt_tz_localize(unit, request):
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
ser = pd.Series(
[datetime(year=2023, month=1, day=2, hour=3), None],
dtype=ArrowDtype(pa.timestamp(unit)),
)
result = ser.dt.tz_localize("US/Pacific")
expected = pd.Series(
[datetime(year=2023, month=1, day=2, hour=3), None],
dtype=ArrowDtype(pa.timestamp(unit, "US/Pacific")),
exp_data = pa.array(
[datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit)
)
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
expected = pd.Series(ArrowExtensionArray(exp_data))
tm.assert_series_equal(result, expected)


# Each case pairs a ``nonexistent`` policy passed to ``Series.dt.tz_localize``
# with the naive wall-clock value the 02:30 input below is expected to become.
@pytest.mark.parametrize(
"nonexistent, exp_date",
[
["shift_forward", datetime(year=2023, month=3, day=12, hour=3)],
["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")],
],
)
def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
# Checks that tz_localize on a pyarrow-backed timestamp Series honours the
# ``nonexistent`` argument, by comparing against pyarrow's own
# ``assume_timezone`` applied to the already-shifted expected value.
#
# On Windows CI the test is marked as an expected failure raising
# ``pa.ArrowInvalid``; per the reason string, pyarrow needs
# ARROW_TIMEZONE_DATABASE pointed at a tzdata directory there.
if is_platform_windows() and is_ci_environment():
request.node.add_marker(
pytest.mark.xfail(
raises=pa.ArrowInvalid,
reason=(
"TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
"on CI to path to the tzdata for pyarrow."
),
)
)
# Input: one naive nanosecond timestamp (2023-03-12 02:30) plus a null.
# NOTE(review): per the test name, 02:30 on this date is meant to fall in the
# US/Pacific spring-forward gap, i.e. a wall time that does not exist.
ser = pd.Series(
[datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
dtype=ArrowDtype(pa.timestamp("ns")),
)
result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent)
# Build the expectation directly with pyarrow: start from the shifted naive
# value (null preserved) and attach the same timezone.
exp_data = pa.array([exp_date, None], type=pa.timestamp("ns"))
exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
expected = pd.Series(ArrowExtensionArray(exp_data))
tm.assert_series_equal(result, expected)


Expand Down

0 comments on commit c10de3a

Please sign in to comment.