Skip to content

ENH: add expand kw to str.get_dummies #10103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 53 additions & 70 deletions pandas/core/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0):
Pattern or regular expression
flags : int, default 0 (no flags)
re module flags, e.g. re.IGNORECASE
expand : None or bool, default None
* If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.

Returns
-------
extracted groups : Series (one group) or DataFrame (multiple groups)
extracted groups : Series/Index or DataFrame/MultiIndex of objects
Note that dtype of the result is always object, even when no match is
found and the result is a Series or DataFrame containing only NaN
values.

Examples
--------
A pattern with one group will return a Series. Non-matches will be NaN.
A pattern with one group returns a Series. Non-matches will be NaN.

>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
0 1
Expand Down Expand Up @@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0):
1 b 2
2 NaN NaN

"""
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index
Or you can specify ``expand=False`` to return Series.

>>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
0 [a, 1]
1 [b, 2]
2 [nan, 3]
Name: [0, 1], dtype: object
"""
regex = re.compile(pat, flags=flags)
# just to be safe, check this
if regex.groups == 0:
Expand All @@ -490,18 +497,9 @@ def f(x):
result = np.array([f(val)[0] for val in arr], dtype=object)
name = _get_single_group_name(regex)
else:
if isinstance(arr, Index):
raise ValueError("only one regex group is supported with Index")
name = None
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
columns = [names.get(1 + i, i) for i in range(regex.groups)]
if arr.empty:
result = DataFrame(columns=columns, dtype=object)
else:
result = DataFrame([f(val) for val in arr],
columns=columns,
index=arr.index,
dtype=object)
name = [names.get(1 + i, i) for i in range(regex.groups)]
result = np.array([f(val) for val in arr], dtype=object)
return result, name


Expand All @@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'):
----------
sep : string, default "|"
String to split on.
expand : bool, default True
* If True, return DataFrame/MultiIndex expanding dimensionality.
* If False, return Series/Index.

Returns
-------
dummies : DataFrame
dummies : Series/Index or DataFrame/MultiIndex of objects

Examples
--------
Expand All @@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'):
--------
pandas.get_dummies
"""
from pandas.core.frame import DataFrame
from pandas.core.index import Index

# GH9980, Index.str does not support get_dummies() as it returns a frame
if isinstance(arr, Index):
raise TypeError("get_dummies is not supported for string methods on Index")

# TODO remove this hack?
arr = arr.fillna('')
try:
arr = sep + arr + sep
Expand All @@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'):
for i, t in enumerate(tags):
pat = sep + t + sep
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
return DataFrame(dummies, arr.index, tags)
return dummies, tags


def str_join(arr, sep):
Expand Down Expand Up @@ -1081,7 +1075,10 @@ def __iter__(self):
i += 1
g = self.get(i)

def _wrap_result(self, result, use_codes=True, name=None):
def _wrap_result(self, result, use_codes=True, name=None, expand=False):

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

# for category, we do the stuff on the categories, so blow it up
# to the full series again
Expand All @@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None):
# can be merged to _wrap_result_expand in v0.17
from pandas.core.series import Series
from pandas.core.frame import DataFrame
from pandas.core.index import Index
from pandas.core.index import Index, MultiIndex

if not hasattr(result, 'ndim'):
return result
name = name or getattr(result, 'name', None) or self._orig.name

if result.ndim == 1:
if isinstance(self._orig, Index):
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if is_bool_dtype(result):
return result
return Index(result, name=name)
return Series(result, index=self._orig.index, name=name)
else:
assert result.ndim < 3
return DataFrame(result, index=self._orig.index)

def _wrap_result_expand(self, result, expand=False):
if not isinstance(expand, bool):
raise ValueError("expand must be True or False")

# for category, we do the stuff on the categories, so blow it up
# to the full series again
if self._is_categorical:
result = take_1d(result, self._orig.cat.codes)

from pandas.core.index import Index, MultiIndex
if not hasattr(result, 'ndim'):
return result

if isinstance(self._orig, Index):
name = getattr(result, 'name', None)
# if result is a boolean np.array, return the np.array
# instead of wrapping it into a boolean Index (GH 8875)
if hasattr(result, 'dtype') and is_bool_dtype(result):
Expand All @@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False):
result = list(result)
return MultiIndex.from_tuples(result, names=name)
else:
return Index(result, name=name)
return Index(result, name=name, tupleize_cols=False)
else:
index = self._orig.index
if expand:
Expand All @@ -1148,30 +1117,34 @@ def cons_row(x):
return [ x ]
cons = self._orig._constructor_expanddim
data = [cons_row(x) for x in result]
return cons(data, index=index)
return cons(data, index=index, columns=name,
dtype=result.dtype)
else:
name = getattr(result, 'name', None)
if result.ndim > 1:
result = list(result)
cons = self._orig._constructor
return cons(result, name=name, index=index)

@copy(str_cat)
def cat(self, others=None, sep=None, na_rep=None):
data = self._orig if self._is_categorical else self._data
result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
if not hasattr(result, 'ndim'):
# str_cat may result in np.nan or str
return result
return self._wrap_result(result, use_codes=(not self._is_categorical))


@deprecate_kwarg('return_type', 'expand',
mapping={'series': False, 'frame': True})
@copy(str_split)
def split(self, pat=None, n=-1, expand=False):
result = str_split(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_rsplit)
def rsplit(self, pat=None, n=-1, expand=False):
result = str_rsplit(self._data, pat, n=n)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

_shared_docs['str_partition'] = ("""
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
Expand Down Expand Up @@ -1222,15 +1195,15 @@ def rsplit(self, pat=None, n=-1, expand=False):
def partition(self, pat=' ', expand=True):
f = lambda x: x.partition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@Appender(_shared_docs['str_partition'] % {'side': 'last',
'return': '3 elements containing two empty strings, followed by the string itself',
'also': 'partition : Split the string at the first occurrence of `sep`'})
def rpartition(self, pat=' ', expand=True):
f = lambda x: x.rpartition(pat)
result = _na_map(f, self._data)
return self._wrap_result_expand(result, expand=expand)
return self._wrap_result(result, expand=expand)

@copy(str_get)
def get(self, i):
Expand Down Expand Up @@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs):
return self._wrap_result(result)

@copy(str_get_dummies)
def get_dummies(self, sep='|'):
def get_dummies(self, sep='|', expand=True):
# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
data = self._orig.astype(str) if self._is_categorical else self._data
result = str_get_dummies(data, sep)
return self._wrap_result(result, use_codes=(not self._is_categorical))
result, name = str_get_dummies(data, sep)
return self._wrap_result(result, use_codes=(not self._is_categorical),
name=name, expand=expand)

@copy(str_translate)
def translate(self, table, deletechars=None):
Expand All @@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None):
findall = _pat_wrapper(str_findall, flags=True)

@copy(str_extract)
def extract(self, pat, flags=0):
result, name = str_extract(self._data, pat, flags=flags)
return self._wrap_result(result, name=name)
def extract(self, pat, flags=0, expand=None):
result, name = str_extract(self._orig, pat, flags=flags)
if expand is None and hasattr(result, 'ndim'):
# for compatibility with previous behavior
if len(result) == 0:
# for empty input
expand = True if isinstance(name, list) else False
elif result.ndim > 1:
expand = True
else:
expand = False
return self._wrap_result(result, name=name, use_codes=False, expand=expand)

_shared_docs['find'] = ("""
Return %(side)s indexes in each strings in the Series/Index
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3714,6 +3714,7 @@ def test_str_accessor_api_for_categorical(self):


for func, args, kwargs in func_defs:
print(func, args, kwargs, c)
res = getattr(c.str, func)(*args, **kwargs)
exp = getattr(s.str, func)(*args, **kwargs)

Expand Down
Loading