Skip to content

Commit 280a88f

Browse files
jorisvandenbosschejreback
authored andcommitted
API: consistent __array__ for datetime-like ExtensionArrays (#23593)
1 parent c2d4a1a commit 280a88f

File tree

6 files changed

+139
-23
lines changed

6 files changed

+139
-23
lines changed

pandas/core/arrays/datetimelike.py

+6
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,12 @@ def _formatter(self, boxed=False):
391391
def nbytes(self):
392392
return self._data.nbytes
393393

394+
def __array__(self, dtype=None):
395+
# used for Timedelta/DatetimeArray, overwritten by PeriodArray
396+
if is_object_dtype(dtype):
397+
return np.array(list(self), dtype=object)
398+
return self._data
399+
394400
@property
395401
def shape(self):
396402
return (len(self),)

pandas/core/arrays/datetimes.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from pandas.core.dtypes.common import (
1717
_INT64_DTYPE, _NS_DTYPE, is_categorical_dtype, is_datetime64_dtype,
1818
is_datetime64_ns_dtype, is_datetime64tz_dtype, is_dtype_equal,
19-
is_extension_type, is_float_dtype, is_int64_dtype, is_object_dtype,
20-
is_period_dtype, is_string_dtype, is_timedelta64_dtype, pandas_dtype)
19+
is_extension_type, is_float_dtype, is_object_dtype, is_period_dtype,
20+
is_string_dtype, is_timedelta64_dtype, pandas_dtype)
2121
from pandas.core.dtypes.dtypes import DatetimeTZDtype
2222
from pandas.core.dtypes.generic import ABCIndexClass, ABCPandasArray, ABCSeries
2323
from pandas.core.dtypes.missing import isna
@@ -524,12 +524,11 @@ def _resolution(self):
524524
# Array-Like / EA-Interface Methods
525525

526526
def __array__(self, dtype=None):
527-
if is_object_dtype(dtype) or (dtype is None and self.tz):
528-
return np.array(list(self), dtype=object)
529-
elif is_int64_dtype(dtype):
530-
return self.asi8
527+
if dtype is None and self.tz:
528+
# The default for tz-aware is object, to preserve tz info
529+
dtype = object
531530

532-
return self._data
531+
return super(DatetimeArray, self).__array__(dtype=dtype)
533532

534533
def __iter__(self):
535534
"""

pandas/core/arrays/period.py

+4
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ def freq(self):
282282
"""
283283
return self.dtype.freq
284284

285+
def __array__(self, dtype=None):
286+
# overriding DatetimelikeArray
287+
return np.array(list(self), dtype=object)
288+
285289
# --------------------------------------------------------------------
286290
# Vectorized analogues of Period properties
287291

pandas/core/arrays/timedeltas.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from pandas.core.dtypes.common import (
1818
_NS_DTYPE, _TD_DTYPE, ensure_int64, is_datetime64_dtype, is_float_dtype,
19-
is_int64_dtype, is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
19+
is_integer_dtype, is_list_like, is_object_dtype, is_scalar,
2020
is_string_dtype, is_timedelta64_dtype, is_timedelta64_ns_dtype,
2121
pandas_dtype)
2222
from pandas.core.dtypes.dtypes import DatetimeTZDtype
@@ -265,16 +265,6 @@ def _maybe_clear_freq(self):
265265
# ----------------------------------------------------------------
266266
# Array-Like / EA-Interface Methods
267267

268-
def __array__(self, dtype=None):
269-
# TODO(https://github.com/pandas-dev/pandas/pull/23593)
270-
# Maybe push to parent once datetimetz __array__ is figured out.
271-
if is_object_dtype(dtype):
272-
return np.array(list(self), dtype=object)
273-
elif is_int64_dtype(dtype):
274-
return self.asi8
275-
276-
return self._data
277-
278268
@Appender(dtl.DatetimeLikeArrayMixin._validate_fill_value.__doc__)
279269
def _validate_fill_value(self, fill_value):
280270
if isna(fill_value):

pandas/tests/arrays/test_datetimelike.py

+117-5
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,48 @@ def test_round(self, tz_naive_fixture):
240240
expected = dti - pd.Timedelta(minutes=1)
241241
tm.assert_index_equal(result, expected)
242242

243+
def test_array_interface(self, datetime_index):
244+
arr = DatetimeArray(datetime_index)
245+
246+
# default asarray gives the same underlying data (for tz naive)
247+
result = np.asarray(arr)
248+
expected = arr._data
249+
assert result is expected
250+
tm.assert_numpy_array_equal(result, expected)
251+
result = np.array(arr, copy=False)
252+
assert result is expected
253+
tm.assert_numpy_array_equal(result, expected)
254+
255+
# specifying M8[ns] gives the same result as default
256+
result = np.asarray(arr, dtype='datetime64[ns]')
257+
expected = arr._data
258+
assert result is expected
259+
tm.assert_numpy_array_equal(result, expected)
260+
result = np.array(arr, dtype='datetime64[ns]', copy=False)
261+
assert result is expected
262+
tm.assert_numpy_array_equal(result, expected)
263+
result = np.array(arr, dtype='datetime64[ns]')
264+
assert result is not expected
265+
tm.assert_numpy_array_equal(result, expected)
266+
267+
# to object dtype
268+
result = np.asarray(arr, dtype=object)
269+
expected = np.array(list(arr), dtype=object)
270+
tm.assert_numpy_array_equal(result, expected)
271+
272+
# to other dtype always copies
273+
result = np.asarray(arr, dtype='int64')
274+
assert result is not arr.asi8
275+
assert not np.may_share_memory(arr, result)
276+
expected = arr.asi8.copy()
277+
tm.assert_numpy_array_equal(result, expected)
278+
279+
# other dtypes handled by numpy
280+
for dtype in ['float64', str]:
281+
result = np.asarray(arr, dtype=dtype)
282+
expected = np.asarray(arr).astype(dtype)
283+
tm.assert_numpy_array_equal(result, expected)
284+
243285
def test_array_object_dtype(self, tz_naive_fixture):
244286
# GH#23524
245287
tz = tz_naive_fixture
@@ -255,7 +297,7 @@ def test_array_object_dtype(self, tz_naive_fixture):
255297
result = np.array(dti, dtype=object)
256298
tm.assert_numpy_array_equal(result, expected)
257299

258-
def test_array(self, tz_naive_fixture):
300+
def test_array_tz(self, tz_naive_fixture):
259301
# GH#23524
260302
tz = tz_naive_fixture
261303
dti = pd.date_range('2016-01-01', periods=3, tz=tz)
@@ -265,13 +307,18 @@ def test_array(self, tz_naive_fixture):
265307
result = np.array(arr, dtype='M8[ns]')
266308
tm.assert_numpy_array_equal(result, expected)
267309

310+
result = np.array(arr, dtype='datetime64[ns]')
311+
tm.assert_numpy_array_equal(result, expected)
312+
268313
# check that we are not making copies when setting copy=False
269314
result = np.array(arr, dtype='M8[ns]', copy=False)
270315
assert result.base is expected.base
271316
assert result.base is not None
317+
result = np.array(arr, dtype='datetime64[ns]', copy=False)
318+
assert result.base is expected.base
319+
assert result.base is not None
272320

273321
def test_array_i8_dtype(self, tz_naive_fixture):
274-
# GH#23524
275322
tz = tz_naive_fixture
276323
dti = pd.date_range('2016-01-01', periods=3, tz=tz)
277324
arr = DatetimeArray(dti)
@@ -283,10 +330,10 @@ def test_array_i8_dtype(self, tz_naive_fixture):
283330
result = np.array(arr, dtype=np.int64)
284331
tm.assert_numpy_array_equal(result, expected)
285332

286-
# check that we are not making copies when setting copy=False
333+
# check that we are still making copies when setting copy=False
287334
result = np.array(arr, dtype='i8', copy=False)
288-
assert result.base is expected.base
289-
assert result.base is not None
335+
assert result.base is not expected.base
336+
assert result.base is None
290337

291338
def test_from_array_keeps_base(self):
292339
# Ensure that DatetimeArray._data.base isn't lost.
@@ -470,6 +517,48 @@ def test_int_properties(self, timedelta_index, propname):
470517

471518
tm.assert_numpy_array_equal(result, expected)
472519

520+
def test_array_interface(self, timedelta_index):
521+
arr = TimedeltaArray(timedelta_index)
522+
523+
# default asarray gives the same underlying data
524+
result = np.asarray(arr)
525+
expected = arr._data
526+
assert result is expected
527+
tm.assert_numpy_array_equal(result, expected)
528+
result = np.array(arr, copy=False)
529+
assert result is expected
530+
tm.assert_numpy_array_equal(result, expected)
531+
532+
# specifying m8[ns] gives the same result as default
533+
result = np.asarray(arr, dtype='timedelta64[ns]')
534+
expected = arr._data
535+
assert result is expected
536+
tm.assert_numpy_array_equal(result, expected)
537+
result = np.array(arr, dtype='timedelta64[ns]', copy=False)
538+
assert result is expected
539+
tm.assert_numpy_array_equal(result, expected)
540+
result = np.array(arr, dtype='timedelta64[ns]')
541+
assert result is not expected
542+
tm.assert_numpy_array_equal(result, expected)
543+
544+
# to object dtype
545+
result = np.asarray(arr, dtype=object)
546+
expected = np.array(list(arr), dtype=object)
547+
tm.assert_numpy_array_equal(result, expected)
548+
549+
# to other dtype always copies
550+
result = np.asarray(arr, dtype='int64')
551+
assert result is not arr.asi8
552+
assert not np.may_share_memory(arr, result)
553+
expected = arr.asi8.copy()
554+
tm.assert_numpy_array_equal(result, expected)
555+
556+
# other dtypes handled by numpy
557+
for dtype in ['float64', str]:
558+
result = np.asarray(arr, dtype=dtype)
559+
expected = np.asarray(arr).astype(dtype)
560+
tm.assert_numpy_array_equal(result, expected)
561+
473562
def test_take_fill_valid(self, timedelta_index):
474563
tdi = timedelta_index
475564
arr = TimedeltaArray(tdi)
@@ -543,3 +632,26 @@ def test_int_properties(self, period_index, propname):
543632
expected = np.array(getattr(pi, propname))
544633

545634
tm.assert_numpy_array_equal(result, expected)
635+
636+
def test_array_interface(self, period_index):
637+
arr = PeriodArray(period_index)
638+
639+
# default asarray gives objects
640+
result = np.asarray(arr)
641+
expected = np.array(list(arr), dtype=object)
642+
tm.assert_numpy_array_equal(result, expected)
643+
644+
# to object dtype (same as default)
645+
result = np.asarray(arr, dtype=object)
646+
tm.assert_numpy_array_equal(result, expected)
647+
648+
# to other dtypes
649+
with pytest.raises(TypeError):
650+
np.asarray(arr, dtype='int64')
651+
652+
with pytest.raises(TypeError):
653+
np.asarray(arr, dtype='float64')
654+
655+
result = np.asarray(arr, dtype='S20')
656+
expected = np.asarray(arr).astype('S20')
657+
tm.assert_numpy_array_equal(result, expected)

pandas/tests/extension/base/interface.py

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pandas.core.dtypes.dtypes import ExtensionDtype
55

66
import pandas as pd
7+
import pandas.util.testing as tm
78

89
from .base import BaseExtensionTests
910

@@ -33,6 +34,10 @@ def test_array_interface(self, data):
3334
result = np.array(data)
3435
assert result[0] == data[0]
3536

37+
result = np.array(data, dtype=object)
38+
expected = np.array(list(data), dtype=object)
39+
tm.assert_numpy_array_equal(result, expected)
40+
3641
def test_is_extension_array_dtype(self, data):
3742
assert is_extension_array_dtype(data)
3843
assert is_extension_array_dtype(data.dtype)

0 commit comments

Comments
 (0)