@@ -424,17 +424,21 @@ def str_extract(arr, pat, flags=0):
424424 Pattern or regular expression
425425 flags : int, default 0 (no flags)
426426 re module flags, e.g. re.IGNORECASE
427+ expand : None or bool, default None
428+ * If None, return Series/Index (one group) or DataFrame/MultiIndex (multiple groups)
429+ * If True, return DataFrame/MultiIndex expanding dimensionality.
430+ * If False, return Series/Index.
427431
428432 Returns
429433 -------
430- extracted groups : Series (one group) or DataFrame (multiple groups)
434+ extracted groups : Series/Index or DataFrame/MultiIndex of objects
431435 Note that dtype of the result is always object, even when no match is
432436 found and the result is a Series or DataFrame containing only NaN
433437 values.
434438
435439 Examples
436440 --------
437- A pattern with one group will return a Series. Non-matches will be NaN.
441+ A pattern with one group returns a Series. Non-matches will be NaN.
438442
439443 >>> Series(['a1', 'b2', 'c3']).str.extract('[ab](\d)')
440444 0 1
@@ -466,11 +470,14 @@ def str_extract(arr, pat, flags=0):
466470 1 b 2
467471 2 NaN NaN
468472
469- """
470- from pandas .core .series import Series
471- from pandas .core .frame import DataFrame
472- from pandas .core .index import Index
473+ Or you can specify ``expand=False`` to return a Series.
473474
475+ >>> pd.Series(['a1', 'b2', 'c3']).str.extract('([ab])?(\d)', expand=False)
476+ 0 [a, 1]
477+ 1 [b, 2]
478+ 2 [nan, 3]
479+ Name: [0, 1], dtype: object
480+ """
474481 regex = re .compile (pat , flags = flags )
475482 # just to be safe, check this
476483 if regex .groups == 0 :
@@ -490,18 +497,9 @@ def f(x):
490497 result = np .array ([f (val )[0 ] for val in arr ], dtype = object )
491498 name = _get_single_group_name (regex )
492499 else :
493- if isinstance (arr , Index ):
494- raise ValueError ("only one regex group is supported with Index" )
495- name = None
496500 names = dict (zip (regex .groupindex .values (), regex .groupindex .keys ()))
497- columns = [names .get (1 + i , i ) for i in range (regex .groups )]
498- if arr .empty :
499- result = DataFrame (columns = columns , dtype = object )
500- else :
501- result = DataFrame ([f (val ) for val in arr ],
502- columns = columns ,
503- index = arr .index ,
504- dtype = object )
501+ name = [names .get (1 + i , i ) for i in range (regex .groups )]
502+ result = np .array ([f (val ) for val in arr ], dtype = object )
505503 return result , name
506504
507505
@@ -514,10 +512,13 @@ def str_get_dummies(arr, sep='|'):
514512 ----------
515513 sep : string, default "|"
516514 String to split on.
515+ expand : bool, default True
516+ * If True, return DataFrame/MultiIndex expanding dimensionality.
517+ * If False, return Series/Index.
517518
518519 Returns
519520 -------
520- dummies : DataFrame
521+ dummies : Series/Index or DataFrame/MultiIndex of objects
521522
522523 Examples
523524 --------
@@ -537,14 +538,7 @@ def str_get_dummies(arr, sep='|'):
537538 --------
538539 pandas.get_dummies
539540 """
540- from pandas .core .frame import DataFrame
541541 from pandas .core .index import Index
542-
543- # GH9980, Index.str does not support get_dummies() as it returns a frame
544- if isinstance (arr , Index ):
545- raise TypeError ("get_dummies is not supported for string methods on Index" )
546-
547- # TODO remove this hack?
548542 arr = arr .fillna ('' )
549543 try :
550544 arr = sep + arr + sep
@@ -561,7 +555,7 @@ def str_get_dummies(arr, sep='|'):
561555 for i , t in enumerate (tags ):
562556 pat = sep + t + sep
563557 dummies [:, i ] = lib .map_infer (arr .values , lambda x : pat in x )
564- return DataFrame ( dummies , arr . index , tags )
558+ return dummies , tags
565559
566560
567561def str_join (arr , sep ):
@@ -1081,7 +1075,10 @@ def __iter__(self):
10811075 i += 1
10821076 g = self .get (i )
10831077
1084- def _wrap_result (self , result , use_codes = True , name = None ):
1078+ def _wrap_result (self , result , use_codes = True , name = None , expand = False ):
1079+
1080+ if not isinstance (expand , bool ):
1081+ raise ValueError ("expand must be True or False" )
10851082
10861083 # for category, we do the stuff on the categories, so blow it up
10871084 # to the full series again
@@ -1095,39 +1092,11 @@ def _wrap_result(self, result, use_codes=True, name=None):
10951092 # can be merged to _wrap_result_expand in v0.17
10961093 from pandas .core .series import Series
10971094 from pandas .core .frame import DataFrame
1098- from pandas .core .index import Index
1095+ from pandas .core .index import Index , MultiIndex
10991096
1100- if not hasattr (result , 'ndim' ):
1101- return result
11021097 name = name or getattr (result , 'name' , None ) or self ._orig .name
11031098
1104- if result .ndim == 1 :
1105- if isinstance (self ._orig , Index ):
1106- # if result is a boolean np.array, return the np.array
1107- # instead of wrapping it into a boolean Index (GH 8875)
1108- if is_bool_dtype (result ):
1109- return result
1110- return Index (result , name = name )
1111- return Series (result , index = self ._orig .index , name = name )
1112- else :
1113- assert result .ndim < 3
1114- return DataFrame (result , index = self ._orig .index )
1115-
1116- def _wrap_result_expand (self , result , expand = False ):
1117- if not isinstance (expand , bool ):
1118- raise ValueError ("expand must be True or False" )
1119-
1120- # for category, we do the stuff on the categories, so blow it up
1121- # to the full series again
1122- if self ._is_categorical :
1123- result = take_1d (result , self ._orig .cat .codes )
1124-
1125- from pandas .core .index import Index , MultiIndex
1126- if not hasattr (result , 'ndim' ):
1127- return result
1128-
11291099 if isinstance (self ._orig , Index ):
1130- name = getattr (result , 'name' , None )
11311100 # if result is a boolean np.array, return the np.array
11321101 # instead of wrapping it into a boolean Index (GH 8875)
11331102 if hasattr (result , 'dtype' ) and is_bool_dtype (result ):
@@ -1137,7 +1106,7 @@ def _wrap_result_expand(self, result, expand=False):
11371106 result = list (result )
11381107 return MultiIndex .from_tuples (result , names = name )
11391108 else :
1140- return Index (result , name = name )
1109+ return Index (result , name = name , tupleize_cols = False )
11411110 else :
11421111 index = self ._orig .index
11431112 if expand :
@@ -1148,30 +1117,34 @@ def cons_row(x):
11481117 return [ x ]
11491118 cons = self ._orig ._constructor_expanddim
11501119 data = [cons_row (x ) for x in result ]
1151- return cons (data , index = index )
1120+ return cons (data , index = index , columns = name ,
1121+ dtype = result .dtype )
11521122 else :
1153- name = getattr (result , 'name' , None )
1123+ if result .ndim > 1 :
1124+ result = list (result )
11541125 cons = self ._orig ._constructor
11551126 return cons (result , name = name , index = index )
11561127
@copy(str_cat)
def cat(self, others=None, sep=None, na_rep=None):
    # For categorical data the original object (not the categories) is
    # handed to str_cat, and use_codes is switched off when wrapping.
    is_categorical = self._is_categorical
    source = self._orig if is_categorical else self._data
    joined = str_cat(source, others=others, sep=sep, na_rep=na_rep)
    if hasattr(joined, 'ndim'):
        return self._wrap_result(joined, use_codes=(not is_categorical))
    # str_cat may result in a scalar (np.nan or str); hand it back unwrapped
    return joined
11621136
1163-
@deprecate_kwarg('return_type', 'expand',
                 mapping={'series': False, 'frame': True})
@copy(str_split)
def split(self, pat=None, n=-1, expand=False):
    # Split every element, then let _wrap_result build the output
    # container according to ``expand``.
    pieces = str_split(self._data, pat, n=n)
    return self._wrap_result(pieces, expand=expand)
11701143
@copy(str_rsplit)
def rsplit(self, pat=None, n=-1, expand=False):
    # Same flow as split(), but delegating to str_rsplit.
    pieces = str_rsplit(self._data, pat, n=n)
    return self._wrap_result(pieces, expand=expand)
11751148
11761149 _shared_docs ['str_partition' ] = ("""
11771150 Split the string at the %(side)s occurrence of `sep`, and return 3 elements
@@ -1222,15 +1195,15 @@ def rsplit(self, pat=None, n=-1, expand=False):
def partition(self, pat=' ', expand=True):
    # Apply str.partition to each element through _na_map, then wrap the
    # per-element results according to ``expand``.
    def split_at_first(x):
        return x.partition(pat)

    parts = _na_map(split_at_first, self._data)
    return self._wrap_result(parts, expand=expand)
12261199
@Appender(_shared_docs['str_partition'] % {
    'side': 'last',
    'return': '3 elements containing two empty strings, followed by the string itself',
    'also': 'partition : Split the string at the first occurrence of `sep`'})
def rpartition(self, pat=' ', expand=True):
    # Apply str.rpartition to each element through _na_map, then wrap the
    # per-element results according to ``expand``.
    def split_at_last(x):
        return x.rpartition(pat)

    parts = _na_map(split_at_last, self._data)
    return self._wrap_result(parts, expand=expand)
12341207
12351208 @copy (str_get )
12361209 def get (self , i ):
@@ -1371,12 +1344,13 @@ def wrap(self, width, **kwargs):
13711344 return self ._wrap_result (result )
13721345
@copy(str_get_dummies)
def get_dummies(self, sep='|', expand=True):
    # we need to cast to Series of strings as only that has all
    # methods available for making the dummies...
    is_categorical = self._is_categorical
    if is_categorical:
        strings = self._orig.astype(str)
    else:
        strings = self._data
    # str_get_dummies hands back the indicator matrix and its tags;
    # the tags are passed on as ``name`` for the wrapped result.
    dummies, tags = str_get_dummies(strings, sep)
    return self._wrap_result(dummies, use_codes=(not is_categorical),
                             name=tags, expand=expand)
13801354
13811355 @copy (str_translate )
13821356 def translate (self , table , deletechars = None ):
@@ -1389,9 +1363,18 @@ def translate(self, table, deletechars=None):
13891363 findall = _pat_wrapper (str_findall , flags = True )
13901364
@copy(str_extract)
def extract(self, pat, flags=0, expand=None):
    extracted, name = str_extract(self._orig, pat, flags=flags)
    if expand is None and hasattr(extracted, 'ndim'):
        # to be compat with previous behavior, infer ``expand``
        if len(extracted) == 0:
            # for empty input the shape says nothing, so go by whether
            # ``name`` is a list
            expand = isinstance(name, list)
        else:
            # more than one dimension -> expand
            expand = extracted.ndim > 1
    return self._wrap_result(extracted, name=name, use_codes=False,
                             expand=expand)
13951378
13961379 _shared_docs ['find' ] = ("""
13971380 Return %(side)s indexes in each strings in the Series/Index
0 commit comments