ENH: general concat with ExtensionArrays through find_common_type (pa…
jorisvandenbossche authored and rhshadrach committed May 10, 2020
1 parent 35d738d commit 96101f6
Showing 15 changed files with 209 additions and 131 deletions.
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v1.1.0.rst
@@ -251,6 +251,9 @@ Backwards incompatible API changes
- :meth:`DataFrame.at` and :meth:`Series.at` will raise a ``TypeError`` instead of a ``ValueError`` if an incompatible key is passed, and ``KeyError`` if a missing key is passed, matching the behavior of ``.loc[]`` (:issue:`31722`)
- Passing an integer dtype other than ``int64`` to ``np.array(period_index, dtype=...)`` will now raise ``TypeError`` instead of incorrectly using ``int64`` (:issue:`32255`)
- Passing an invalid ``fill_value`` to :meth:`Categorical.take` raises a ``ValueError`` instead of ``TypeError`` (:issue:`33660`)
- Combining a ``Categorical`` that has integer categories and contains missing values
with a float dtype column in operations such as :func:`concat` or :meth:`~DataFrame.append`
will now result in a float column instead of an object dtyped column (:issue:`33607`)
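As a small illustration of this change (a sketch assuming pandas 1.1 with this commit; earlier versions produced an object column):

    import pandas as pd

    # Categorical column with integer categories and a missing value
    df1 = pd.DataFrame({"a": pd.Categorical([1, 2, None])})
    # Plain float column
    df2 = pd.DataFrame({"a": [3.5]})

    result = pd.concat([df1, df2], ignore_index=True)
    print(result["a"].dtype)  # float64 with this change, object before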

``MultiIndex.get_indexer`` interprets `method` argument differently
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7 changes: 6 additions & 1 deletion pandas/core/arrays/base.py
@@ -1004,7 +1004,7 @@ def _concat_same_type(
cls, to_concat: Sequence["ExtensionArray"]
) -> "ExtensionArray":
"""
Concatenate multiple array.
Concatenate multiple arrays of this dtype.
Parameters
----------
@@ -1014,6 +1014,11 @@ def _concat_same_type(
-------
ExtensionArray
"""
# Implementer note: this method will only be called with a sequence of
# ExtensionArrays of this class and with the same dtype as self. This
# should allow "easy" concatenation (no upcasting needed), and result
# in a new ExtensionArray of the same dtype.
# Note: this strict behaviour is only guaranteed starting with pandas 1.1
raise AbstractMethodError(cls)

# The _can_hold_na attribute is set to True so that pandas internals
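To make the implementer note above concrete, here is a minimal sketch of what a third-party ExtensionArray can now do; ``MyArray`` and its ``_data`` attribute are hypothetical, not pandas API:

    import numpy as np

    class MyArray:  # stand-in for an ExtensionArray subclass backed by a numpy array
        def __init__(self, data):
            self._data = np.asarray(data)

        @classmethod
        def _concat_same_type(cls, to_concat):
            # As of pandas 1.1, every array in ``to_concat`` is guaranteed to have
            # the same dtype as this class, so no upcasting logic is needed here:
            # concatenate the backing ndarrays and rewrap.
            return cls(np.concatenate([arr._data for arr in to_concat]))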
4 changes: 2 additions & 2 deletions pandas/core/arrays/categorical.py
@@ -2296,9 +2296,9 @@ def _can_hold_na(self):

@classmethod
def _concat_same_type(self, to_concat):
from pandas.core.dtypes.concat import concat_categorical
from pandas.core.dtypes.concat import union_categoricals

return concat_categorical(to_concat)
return union_categoricals(to_concat)

def isin(self, values):
"""
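``union_categoricals`` (the public helper that now backs ``Categorical._concat_same_type``) concatenates the values and takes the union of the categories; for example:

    import pandas as pd
    from pandas.api.types import union_categoricals

    a = pd.Categorical(["a", "b"], categories=["a", "b", "c"])
    b = pd.Categorical(["b", "c"], categories=["a", "b", "c"])

    # Identical dtypes on both inputs -- the only case _concat_same_type
    # has to handle under the new contract.
    print(union_categoricals([a, b]))
    # ['a', 'b', 'b', 'c']
    # Categories (3, object): ['a', 'b', 'c']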
15 changes: 13 additions & 2 deletions pandas/core/arrays/integer.py
@@ -1,11 +1,11 @@
import numbers
from typing import TYPE_CHECKING, Tuple, Type, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
import warnings

import numpy as np

from pandas._libs import lib, missing as libmissing
from pandas._typing import ArrayLike
from pandas._typing import ArrayLike, DtypeObj
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.util._decorators import cache_readonly
@@ -96,6 +96,17 @@ def construct_array_type(cls) -> Type["IntegerArray"]:
"""
return IntegerArray

def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
# for now only handle other integer types
if not all(isinstance(t, _IntegerDtype) for t in dtypes):
return None
np_dtype = np.find_common_type(
[t.numpy_dtype for t in dtypes], [] # type: ignore
)
if np.issubdtype(np_dtype, np.integer):
return _dtypes[str(np_dtype)]
return None

def __from_arrow__(
self, array: Union["pyarrow.Array", "pyarrow.ChunkedArray"]
) -> "IntegerArray":
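The new ``_get_common_dtype`` lets nullable integer dtypes of different widths find a common nullable dtype instead of falling back to object; a sketch of the expected behaviour with this change:

    import pandas as pd

    s1 = pd.Series([1, 2], dtype="Int8")
    s2 = pd.Series([3, None], dtype="Int32")

    # np.find_common_type(int8, int32) -> int32, mapped back to the
    # nullable Int32 dtype, so the concat result stays nullable integer.
    print(pd.concat([s1, s2]).dtype)  # Int32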
22 changes: 1 addition & 21 deletions pandas/core/arrays/sparse/array.py
@@ -952,27 +952,7 @@ def copy(self):

@classmethod
def _concat_same_type(cls, to_concat):
fill_values = [x.fill_value for x in to_concat]

fill_value = fill_values[0]

# np.nan isn't a singleton, so we may end up with multiple
# NaNs here, so we ignore the all NA case too.
if not (len(set(fill_values)) == 1 or isna(fill_values).all()):
warnings.warn(
"Concatenating sparse arrays with multiple fill "
f"values: '{fill_values}'. Picking the first and "
"converting the rest.",
PerformanceWarning,
stacklevel=6,
)
keep = to_concat[0]
to_concat2 = [keep]

for arr in to_concat[1:]:
to_concat2.append(cls(np.asarray(arr), fill_value=fill_value))

to_concat = to_concat2
fill_value = to_concat[0].fill_value

values = []
length = 0
26 changes: 24 additions & 2 deletions pandas/core/arrays/sparse/dtype.py
@@ -1,11 +1,13 @@
"""Sparse Dtype"""

import re
from typing import TYPE_CHECKING, Any, Tuple, Type
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type
import warnings

import numpy as np

from pandas._typing import Dtype
from pandas._typing import Dtype, DtypeObj
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.cast import astype_nansafe
@@ -352,3 +354,23 @@ def _subtype_with_str(self):
if isinstance(self.fill_value, str):
return type(self.fill_value)
return self.subtype

def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:

fill_values = [x.fill_value for x in dtypes if isinstance(x, SparseDtype)]
fill_value = fill_values[0]

# np.nan isn't a singleton, so we may end up with multiple
# NaNs here, so we ignore the all NA case too.
if not (len(set(fill_values)) == 1 or isna(fill_values).all()):
warnings.warn(
"Concatenating sparse arrays with multiple fill "
f"values: '{fill_values}'. Picking the first and "
"converting the rest.",
PerformanceWarning,
stacklevel=6,
)

# TODO also handle non-numpy other dtypes
np_dtypes = [x.subtype if isinstance(x, SparseDtype) else x for x in dtypes]
return SparseDtype(np.find_common_type(np_dtypes, []), fill_value=fill_value)
32 changes: 31 additions & 1 deletion pandas/core/dtypes/base.py
@@ -6,6 +6,7 @@

import numpy as np

from pandas._typing import DtypeObj
from pandas.errors import AbstractMethodError

from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
@@ -33,11 +34,12 @@ class ExtensionDtype:
* type
* name
The following attributes influence the behavior of the dtype in
The following attributes and methods influence the behavior of the dtype in
pandas operations
* _is_numeric
* _is_boolean
* _get_common_dtype
Optionally one can override construct_array_type for construction
with the name of this dtype via the Registry. See
@@ -322,3 +324,31 @@ def _is_boolean(self) -> bool:
bool
"""
return False

def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
"""
Return the common dtype, if one exists.
Used in `find_common_type` implementation. This is for example used
to determine the resulting dtype in a concat operation.
If no common dtype exists, return None (which gives the other dtypes
the chance to determine a common dtype). If all dtypes in the list
return None, then the common dtype will be "object" dtype (this means
it is never needed to return "object" dtype from this method itself).
Parameters
----------
dtypes : list of dtypes
The dtypes for which to determine a common dtype. This is a list
of np.dtype or ExtensionDtype instances.
Returns
-------
Common dtype (np.dtype or ExtensionDtype) or None
"""
if len(set(dtypes)) == 1:
# only itself
return self
else:
return None
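The default implementation only recognises the trivial case where all dtypes are identical; a minimal stub dtype (purely illustrative, not a complete ExtensionDtype) shows the two outcomes:

    import numpy as np
    from pandas.api.extensions import ExtensionDtype

    class DemoDtype(ExtensionDtype):
        # Just enough to instantiate; a real dtype would also implement
        # construct_array_type and friends.
        name = "demo"
        type = object

    d = DemoDtype()
    print(d._get_common_dtype([d, DemoDtype()]))        # the dtype itself (all equal)
    print(d._get_common_dtype([d, np.dtype("int64")]))  # None -> caller falls back to object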
14 changes: 11 additions & 3 deletions pandas/core/dtypes/cast.py
@@ -3,7 +3,7 @@
"""

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING, Any, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type

import numpy as np

@@ -1423,7 +1423,7 @@ def maybe_cast_to_datetime(value, dtype, errors: str = "raise"):
return value


def find_common_type(types):
def find_common_type(types: List[DtypeObj]) -> DtypeObj:
"""
Find a common data type among the given dtypes.
@@ -1450,8 +1450,16 @@ def find_common_type(types):
if all(is_dtype_equal(first, t) for t in types[1:]):
return first

# get unique types (dict.fromkeys is used as order-preserving set())
types = list(dict.fromkeys(types).keys())

if any(isinstance(t, ExtensionDtype) for t in types):
return np.object
for t in types:
if isinstance(t, ExtensionDtype):
res = t._get_common_dtype(types)
if res is not None:
return res
return np.dtype("object")

# take lowest unit
if all(is_datetime64_dtype(t) for t in types):
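The dispatch order is now: exact match first, then each ExtensionDtype's ``_get_common_dtype`` hook, and only then the object fallback; numpy-only inputs still go through ``np.find_common_type`` further down. Expected results with this change (later versions may refine some cases):

    import numpy as np
    import pandas as pd
    from pandas.core.dtypes.cast import find_common_type

    print(find_common_type([np.dtype("int64"), np.dtype("float32")]))      # float64
    print(find_common_type([pd.Int8Dtype(), pd.Int64Dtype()]))             # Int64
    print(find_common_type([pd.Int8Dtype(), pd.CategoricalDtype(["a"])]))  # object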