Skip to content

Commit

Permalink
Backport PR pandas-dev#56672 on branch 2.2.x (BUG: dictionary type as…
Browse files Browse the repository at this point in the history
…type categorical using dictionary as categories) (pandas-dev#56723)

Backport PR pandas-dev#56672: BUG: dictionary type astype categorical using dictionary as categories

Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and phofl authored Jan 3, 2024
1 parent 1c3c988 commit 0cd02c5
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ Categorical
^^^^^^^^^^^
- :meth:`Categorical.isin` raising ``InvalidIndexError`` for categorical containing overlapping :class:`Interval` values (:issue:`34974`)
- Bug in :meth:`CategoricalDtype.__eq__` returning ``False`` for unordered categorical data with mixed types (:issue:`55468`)
- Bug when casting ``pa.dictionary`` to :class:`CategoricalDtype` using a ``pa.DictionaryArray`` as categories (:issue:`56672`)

Datetimelike
^^^^^^^^^^^^
Expand Down
46 changes: 28 additions & 18 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
pandas_dtype,
)
from pandas.core.dtypes.dtypes import (
ArrowDtype,
CategoricalDtype,
CategoricalDtypeType,
ExtensionDtype,
)
from pandas.core.dtypes.generic import (
Expand Down Expand Up @@ -443,24 +445,32 @@ def __init__(
values = arr

if dtype.categories is None:
if not isinstance(values, ABCIndex):
# in particular RangeIndex xref test_index_equal_range_categories
values = sanitize_array(values, None)
try:
codes, categories = factorize(values, sort=True)
except TypeError as err:
codes, categories = factorize(values, sort=False)
if dtype.ordered:
# raise, as we don't have a sortable data structure and so
# the user should give us one by specifying categories
raise TypeError(
"'values' is not ordered, please "
"explicitly specify the categories order "
"by passing in a categories argument."
) from err

# we're inferring from values
dtype = CategoricalDtype(categories, dtype.ordered)
if isinstance(values.dtype, ArrowDtype) and issubclass(
values.dtype.type, CategoricalDtypeType
):
arr = values._pa_array.combine_chunks()
categories = arr.dictionary.to_pandas(types_mapper=ArrowDtype)
codes = arr.indices.to_numpy()
dtype = CategoricalDtype(categories, values.dtype.pyarrow_dtype.ordered)
else:
if not isinstance(values, ABCIndex):
# in particular RangeIndex xref test_index_equal_range_categories
values = sanitize_array(values, None)
try:
codes, categories = factorize(values, sort=True)
except TypeError as err:
codes, categories = factorize(values, sort=False)
if dtype.ordered:
# raise, as we don't have a sortable data structure and so
# the user should give us one by specifying categories
raise TypeError(
"'values' is not ordered, please "
"explicitly specify the categories order "
"by passing in a categories argument."
) from err

# we're inferring from values
dtype = CategoricalDtype(categories, dtype.ordered)

elif isinstance(values.dtype, CategoricalDtype):
old_codes = extract_array(values)._codes
Expand Down
16 changes: 16 additions & 0 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,6 +3234,22 @@ def test_factorize_chunked_dictionary():
tm.assert_index_equal(res_uniques, exp_uniques)


def test_dictionary_astype_categorical():
# GH#56672
arrs = [
pa.array(np.array(["a", "x", "c", "a"])).dictionary_encode(),
pa.array(np.array(["a", "d", "c"])).dictionary_encode(),
]
ser = pd.Series(ArrowExtensionArray(pa.chunked_array(arrs)))
result = ser.astype("category")
categories = pd.Index(["a", "x", "c", "d"], dtype=ArrowDtype(pa.string()))
expected = pd.Series(
["a", "x", "c", "a", "a", "d", "c"],
dtype=pd.CategoricalDtype(categories=categories),
)
tm.assert_series_equal(result, expected)


def test_arrow_floordiv():
# GH 55561
a = pd.Series([-7], dtype="int64[pyarrow]")
Expand Down

0 comments on commit 0cd02c5

Please sign in to comment.