Skip to content

Commit 0cee41f

Browse files
authored
ENH: Add ea support to get_dummies (#50849)
* ENH: Add ea support to get_dummies * Fix mypy
1 parent 38ad5ce commit 0cee41f

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ Other enhancements
160160
- Added ``name`` parameter to :meth:`IntervalIndex.from_breaks`, :meth:`IntervalIndex.from_arrays` and :meth:`IntervalIndex.from_tuples` (:issue:`48911`)
161161
- Improve exception message when using :func:`assert_frame_equal` on a :class:`DataFrame` to include the column that is compared (:issue:`50323`)
162162
- Improved error message for :func:`merge_asof` when join-columns were duplicated (:issue:`50102`)
163+
- Added support for extension array dtypes to :func:`get_dummies` (:func:`32430`)
163164
- Added :meth:`Index.infer_objects` analogous to :meth:`Series.infer_objects` (:issue:`50034`)
164165
- Added ``copy`` parameter to :meth:`Series.infer_objects` and :meth:`DataFrame.infer_objects`, passing ``False`` will avoid making copies for series or columns that are already non-object or where no better dtype can be inferred (:issue:`50096`)
165166
- :meth:`DataFrame.plot.hist` now recognizes ``xlabel`` and ``ylabel`` arguments (:issue:`49793`)

pandas/core/reshape/encoding.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
is_integer_dtype,
1717
is_list_like,
1818
is_object_dtype,
19+
pandas_dtype,
1920
)
2021

2122
from pandas.core.arrays import SparseArray
@@ -240,9 +241,9 @@ def _get_dummies_1d(
240241

241242
if dtype is None:
242243
dtype = np.dtype(bool)
243-
dtype = np.dtype(dtype)
244+
_dtype = pandas_dtype(dtype)
244245

245-
if is_object_dtype(dtype):
246+
if is_object_dtype(_dtype):
246247
raise ValueError("dtype=object is not a valid dtype for get_dummies")
247248

248249
def get_empty_frame(data) -> DataFrame:
@@ -317,7 +318,12 @@ def get_empty_frame(data) -> DataFrame:
317318

318319
else:
319320
# take on axis=1 + transpose to ensure ndarray layout is column-major
320-
dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=1).T
321+
eye_dtype: NpDtype
322+
if isinstance(_dtype, np.dtype):
323+
eye_dtype = _dtype
324+
else:
325+
eye_dtype = np.bool_
326+
dummy_mat = np.eye(number_of_cols, dtype=eye_dtype).take(codes, axis=1).T
321327

322328
if not dummy_na:
323329
# reset NaN GH4446
@@ -327,7 +333,7 @@ def get_empty_frame(data) -> DataFrame:
327333
# remove first GH12042
328334
dummy_mat = dummy_mat[:, 1:]
329335
dummy_cols = dummy_cols[1:]
330-
return DataFrame(dummy_mat, index=index, columns=dummy_cols)
336+
return DataFrame(dummy_mat, index=index, columns=dummy_cols, dtype=_dtype)
331337

332338

333339
def from_dummies(

pandas/tests/reshape/test_get_dummies.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,3 +657,23 @@ def test_get_dummies_with_string_values(self, values):
657657

658658
with pytest.raises(TypeError, match=msg):
659659
get_dummies(df, columns=values)
660+
661+
def test_get_dummies_ea_dtype_series(self, any_numeric_ea_dtype):
662+
# GH#32430
663+
ser = Series(list("abca"))
664+
result = get_dummies(ser, dtype=any_numeric_ea_dtype)
665+
expected = DataFrame(
666+
{"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]},
667+
dtype=any_numeric_ea_dtype,
668+
)
669+
tm.assert_frame_equal(result, expected)
670+
671+
def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_dtype):
672+
# GH#32430
673+
df = DataFrame({"x": list("abca")})
674+
result = get_dummies(df, dtype=any_numeric_ea_dtype)
675+
expected = DataFrame(
676+
{"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]},
677+
dtype=any_numeric_ea_dtype,
678+
)
679+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)