Skip to content

Commit f30f63c

Browse files
committed
ENH: add expand kw to str.extract and str.get_dummies
1 parent 9b04bd0 commit f30f63c

File tree

2 files changed

+212
-75
lines changed

2 files changed

+212
-75
lines changed

pandas/core/strings.py

+54-63
Original file line numberDiff line numberDiff line change
@@ -421,17 +421,21 @@ def str_extract(arr, pat, flags=0):
421421
Pattern or regular expression
422422
flags : int, default 0 (no flags)
423423
re module flags, e.g. re.IGNORECASE
424+
expand : None or bool, default None
425+
* If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
426+
* If True, return DataFrame/MultiIndex expanding dimensionality.
427+
* If False, return Series/Index.
424428
425429
Returns
426430
-------
427-
extracted groups : Series (one group) or DataFrame (multiple groups)
431+
extracted groups : Series/Index or DataFrame/MultiIndex of objects
428432
Note that dtype of the result is always object, even when no match is
429433
found and the result is a Series or DataFrame containing only NaN
430434
values.
431435
432436
Examples
433437
--------
434-
A pattern with one group will return a Series. Non-matches will be NaN.
438+
A pattern with one group returns a Series. Non-matches will be NaN.
435439
436440
>>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
437441
0 1
@@ -463,11 +467,14 @@ def str_extract(arr, pat, flags=0):
463467
1 b 2
464468
2 NaN NaN
465469
466-
"""
467-
from pandas.core.series import Series
468-
from pandas.core.frame import DataFrame
469-
from pandas.core.index import Index
470+
Or you can specify ``expand=False`` to return Series.
470471
472+
>>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
473+
0 [a, 1]
474+
1 [b, 2]
475+
2 [nan, 3]
476+
Name: [0, 1], dtype: object
477+
"""
471478
regex = re.compile(pat, flags=flags)
472479
# just to be safe, check this
473480
if regex.groups == 0:
@@ -487,18 +494,9 @@ def f(x):
487494
result = np.array([f(val)[0] for val in arr], dtype=object)
488495
name = _get_single_group_name(regex)
489496
else:
490-
if isinstance(arr, Index):
491-
raise ValueError("only one regex group is supported with Index")
492-
name = None
493497
names = dict(zip(regex.groupindex.values(), regex.groupindex.keys()))
494-
columns = [names.get(1 + i, i) for i in range(regex.groups)]
495-
if arr.empty:
496-
result = DataFrame(columns=columns, dtype=object)
497-
else:
498-
result = DataFrame([f(val) for val in arr],
499-
columns=columns,
500-
index=arr.index,
501-
dtype=object)
498+
name = [names.get(1 + i, i) for i in range(regex.groups)]
499+
result = np.array([f(val) for val in arr], dtype=object)
502500
return result, name
503501

504502

@@ -511,10 +509,13 @@ def str_get_dummies(arr, sep='|'):
511509
----------
512510
sep : string, default "|"
513511
String to split on.
512+
expand : bool, default True
513+
* If True, return DataFrame/MultiIndex expanding dimensionality.
514+
* If False, return Series/Index.
514515
515516
Returns
516517
-------
517-
dummies : DataFrame
518+
dummies : Series/Index or DataFrame/MultiIndex of objects
518519
519520
Examples
520521
--------
@@ -534,15 +535,15 @@ def str_get_dummies(arr, sep='|'):
534535
--------
535536
pandas.get_dummies
536537
"""
537-
from pandas.core.frame import DataFrame
538538
from pandas.core.index import Index
539-
540-
# GH9980, Index.str does not support get_dummies() as it returns a frame
539+
# TODO: Add fillna GH 10089
541540
if isinstance(arr, Index):
542-
raise TypeError("get_dummies is not supported for string methods on Index")
543-
544-
# TODO remove this hack?
545-
arr = arr.fillna('')
541+
# temp hack
542+
values = arr.values
543+
values[isnull(values)] = ''
544+
arr = Index(values)
545+
else:
546+
arr = arr.fillna('')
546547
try:
547548
arr = sep + arr + sep
548549
except TypeError:
@@ -558,7 +559,7 @@ def str_get_dummies(arr, sep='|'):
558559
for i, t in enumerate(tags):
559560
pat = sep + t + sep
560561
dummies[:, i] = lib.map_infer(arr.values, lambda x: pat in x)
561-
return DataFrame(dummies, arr.index, tags)
562+
return dummies, tags
562563

563564

564565
def str_join(arr, sep):
@@ -1043,40 +1044,19 @@ def __iter__(self):
10431044
i += 1
10441045
g = self.get(i)
10451046

1046-
def _wrap_result(self, result, **kwargs):
1047-
1048-
# leave as it is to keep extract and get_dummies results
1049-
# can be merged to _wrap_result_expand in v0.17
1050-
from pandas.core.series import Series
1051-
from pandas.core.frame import DataFrame
1052-
from pandas.core.index import Index
1053-
1054-
if not hasattr(result, 'ndim'):
1055-
return result
1056-
name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
1057-
1058-
if result.ndim == 1:
1059-
if isinstance(self.series, Index):
1060-
# if result is a boolean np.array, return the np.array
1061-
# instead of wrapping it into a boolean Index (GH 8875)
1062-
if is_bool_dtype(result):
1063-
return result
1064-
return Index(result, name=name)
1065-
return Series(result, index=self.series.index, name=name)
1066-
else:
1067-
assert result.ndim < 3
1068-
return DataFrame(result, index=self.series.index)
1047+
def _wrap_result(self, result, expand=False, name=None):
1048+
from pandas.core.index import Index, MultiIndex
10691049

1070-
def _wrap_result_expand(self, result, expand=False):
10711050
if not isinstance(expand, bool):
10721051
raise ValueError("expand must be True or False")
10731052

1074-
from pandas.core.index import Index, MultiIndex
1053+
if name is None:
1054+
name = getattr(result, 'name', None) or self.series.name
1055+
10751056
if not hasattr(result, 'ndim'):
10761057
return result
10771058

10781059
if isinstance(self.series, Index):
1079-
name = getattr(result, 'name', None)
10801060
# if result is a boolean np.array, return the np.array
10811061
# instead of wrapping it into a boolean Index (GH 8875)
10821062
if hasattr(result, 'dtype') and is_bool_dtype(result):
@@ -1092,10 +1072,12 @@ def _wrap_result_expand(self, result, expand=False):
10921072
if expand:
10931073
cons_row = self.series._constructor
10941074
cons = self.series._constructor_expanddim
1095-
data = [cons_row(x) for x in result]
1096-
return cons(data, index=index)
1075+
data = [cons_row(x, index=name) for x in result]
1076+
return cons(data, index=index, columns=name,
1077+
dtype=result.dtype)
10971078
else:
1098-
name = getattr(result, 'name', None)
1079+
if result.ndim > 1:
1080+
result = list(result)
10991081
cons = self.series._constructor
11001082
return cons(result, name=name, index=index)
11011083

@@ -1109,7 +1091,7 @@ def cat(self, others=None, sep=None, na_rep=None):
11091091
@copy(str_split)
11101092
def split(self, pat=None, n=-1, expand=False):
11111093
result = str_split(self.series, pat, n=n)
1112-
return self._wrap_result_expand(result, expand=expand)
1094+
return self._wrap_result(result, expand=expand)
11131095

11141096
_shared_docs['str_partition'] = ("""
11151097
Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1160,15 +1142,15 @@ def split(self, pat=None, n=-1, expand=False):
11601142
def partition(self, pat=' ', expand=True):
11611143
f = lambda x: x.partition(pat)
11621144
result = _na_map(f, self.series)
1163-
return self._wrap_result_expand(result, expand=expand)
1145+
return self._wrap_result(result, expand=expand)
11641146

11651147
@Appender(_shared_docs['str_partition'] % {'side': 'last',
11661148
'return': '3 elements containing two empty strings, followed by the string itself',
11671149
'also': 'partition : Split the string at the first occurrence of `sep`'})
11681150
def rpartition(self, pat=' ', expand=True):
11691151
f = lambda x: x.rpartition(pat)
11701152
result = _na_map(f, self.series)
1171-
return self._wrap_result_expand(result, expand=expand)
1153+
return self._wrap_result(result, expand=expand)
11721154

11731155
@copy(str_get)
11741156
def get(self, i):
@@ -1309,9 +1291,9 @@ def wrap(self, width, **kwargs):
13091291
return self._wrap_result(result)
13101292

13111293
@copy(str_get_dummies)
1312-
def get_dummies(self, sep='|'):
1313-
result = str_get_dummies(self.series, sep)
1314-
return self._wrap_result(result)
1294+
def get_dummies(self, sep='|', expand=True):
1295+
result, name = str_get_dummies(self.series, sep)
1296+
return self._wrap_result(result, name=name, expand=expand)
13151297

13161298
@copy(str_translate)
13171299
def translate(self, table, deletechars=None):
@@ -1324,9 +1306,18 @@ def translate(self, table, deletechars=None):
13241306
findall = _pat_wrapper(str_findall, flags=True)
13251307

13261308
@copy(str_extract)
1327-
def extract(self, pat, flags=0):
1309+
def extract(self, pat, flags=0, expand=None):
13281310
result, name = str_extract(self.series, pat, flags=flags)
1329-
return self._wrap_result(result, name=name)
1311+
if expand is None and hasattr(result, 'ndim'):
1312+
# to be compat with previous behavior
1313+
if len(result) == 0:
1314+
# for empty input
1315+
expand = True if isinstance(name, list) else False
1316+
elif result.ndim > 1:
1317+
expand = True
1318+
else:
1319+
expand = False
1320+
return self._wrap_result(result, name=name, expand=expand)
13301321

13311322
_shared_docs['find'] = ("""
13321323
Return %(side)s indexes in each strings in the Series/Index

0 commit comments

Comments
 (0)