22from  functools  import  wraps 
33import  re 
44import  textwrap 
5- from  typing  import  Dict , List 
5+ from  typing  import  TYPE_CHECKING ,  Any ,  Callable ,  Dict , List 
66import  warnings 
77
88import  numpy  as  np 
1515    ensure_object ,
1616    is_bool_dtype ,
1717    is_categorical_dtype ,
18+     is_extension_array_dtype ,
1819    is_integer ,
20+     is_integer_dtype ,
1921    is_list_like ,
22+     is_object_dtype ,
2023    is_re ,
2124    is_scalar ,
25+     is_string_dtype ,
2226)
2327from  pandas .core .dtypes .generic  import  (
2428    ABCDataFrame ,
2832)
2933from  pandas .core .dtypes .missing  import  isna 
3034
35+ from  pandas ._typing  import  ArrayLike , Dtype 
3136from  pandas .core .algorithms  import  take_1d 
3237from  pandas .core .base  import  NoNewAttributesMixin 
3338import  pandas .core .common  as  com 
39+ from  pandas .core .construction  import  extract_array 
40+ 
41+ if  TYPE_CHECKING :
42+     from  pandas .arrays  import  StringArray 
3443
3544_cpython_optimized_encoders  =  (
3645    "utf-8" ,
@@ -109,10 +118,79 @@ def cat_safe(list_of_columns: List, sep: str):
109118
110119def  _na_map (f , arr , na_result = np .nan , dtype = object ):
111120    # should really _check_ for NA 
112-     return  _map (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
121+     if  is_extension_array_dtype (arr .dtype ):
122+         # just StringDtype 
123+         arr  =  extract_array (arr )
124+         return  _map_stringarray (f , arr , na_value = na_result , dtype = dtype )
125+     return  _map_object (f , arr , na_mask = True , na_value = na_result , dtype = dtype )
126+ 
127+ 
128+ def  _map_stringarray (
129+     func : Callable [[str ], Any ], arr : "StringArray" , na_value : Any , dtype : Dtype 
130+ ) ->  ArrayLike :
131+     """ 
132+     Map a callable over valid elements of a StringArray. 
133+ 
134+     Parameters 
135+     ---------- 
136+     func : Callable[[str], Any] 
137+         Apply to each valid element. 
138+     arr : StringArray 
139+     na_value : Any 
140+         The value to use for missing values. By default, this is 
141+         the original value (NA). 
142+     dtype : Dtype 
143+         The result dtype to use. Specifying this avoids an intermediate 
144+         object-dtype allocation. 
145+ 
146+     Returns 
147+     ------- 
148+     ArrayLike 
149+         An ExtensionArray for integer or string dtypes, otherwise 
150+         an ndarray. 
151+ 
152+     """ 
153+     from  pandas .arrays  import  IntegerArray , StringArray 
154+ 
155+     mask  =  isna (arr )
156+ 
157+     assert  isinstance (arr , StringArray )
158+     arr  =  np .asarray (arr )
159+ 
160+     if  is_integer_dtype (dtype ):
161+         na_value_is_na  =  isna (na_value )
162+         if  na_value_is_na :
163+             na_value  =  1 
164+         result  =  lib .map_infer_mask (
165+             arr ,
166+             func ,
167+             mask .view ("uint8" ),
168+             convert = False ,
169+             na_value = na_value ,
170+             dtype = np .dtype ("int64" ),
171+         )
172+ 
173+         if  not  na_value_is_na :
174+             mask [:] =  False 
175+ 
176+         return  IntegerArray (result , mask )
177+ 
178+     elif  is_string_dtype (dtype ) and  not  is_object_dtype (dtype ):
179+         # i.e. StringDtype 
180+         result  =  lib .map_infer_mask (
181+             arr , func , mask .view ("uint8" ), convert = False , na_value = na_value 
182+         )
183+         return  StringArray (result )
184+     # TODO: BooleanArray 
185+     else :
186+         # This is when the result type is object. We reach this when 
187+         # -> We know the result type is truly object (e.g. .encode returns bytes 
188+         #    or .findall returns a list). 
189+         # -> We don't know the result type. E.g. `.get` can return anything. 
190+         return  lib .map_infer_mask (arr , func , mask .view ("uint8" ))
113191
114192
115- def  _map (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
193+ def  _map_object (f , arr , na_mask = False , na_value = np .nan , dtype = object ):
116194    if  not  len (arr ):
117195        return  np .ndarray (0 , dtype = dtype )
118196
@@ -143,7 +221,7 @@ def g(x):
143221                except  (TypeError , AttributeError ):
144222                    return  na_value 
145223
146-             return  _map (g , arr , dtype = dtype )
224+             return  _map_object (g , arr , dtype = dtype )
147225        if  na_value  is  not   np .nan :
148226            np .putmask (result , mask , na_value )
149227            if  result .dtype  ==  object :
@@ -634,7 +712,7 @@ def str_replace(arr, pat, repl, n=-1, case=None, flags=0, regex=True):
634712            raise  ValueError ("Cannot use a callable replacement when regex=False" )
635713        f  =  lambda  x : x .replace (pat , repl , n )
636714
637-     return  _na_map (f , arr )
715+     return  _na_map (f , arr ,  dtype = str )
638716
639717
640718def  str_repeat (arr , repeats ):
@@ -685,7 +763,7 @@ def scalar_rep(x):
685763            except  TypeError :
686764                return  str .__mul__ (x , repeats )
687765
688-         return  _na_map (scalar_rep , arr )
766+         return  _na_map (scalar_rep , arr ,  dtype = str )
689767    else :
690768
691769        def  rep (x , r ):
@@ -1150,7 +1228,7 @@ def str_join(arr, sep):
11501228    4                    NaN 
11511229    dtype: object 
11521230    """ 
1153-     return  _na_map (sep .join , arr )
1231+     return  _na_map (sep .join , arr ,  dtype = str )
11541232
11551233
11561234def  str_findall (arr , pat , flags = 0 ):
@@ -1381,7 +1459,7 @@ def str_pad(arr, width, side="left", fillchar=" "):
13811459    else :  # pragma: no cover 
13821460        raise  ValueError ("Invalid side" )
13831461
1384-     return  _na_map (f , arr )
1462+     return  _na_map (f , arr ,  dtype = str )
13851463
13861464
13871465def  str_split (arr , pat = None , n = None ):
@@ -1487,7 +1565,7 @@ def str_slice(arr, start=None, stop=None, step=None):
14871565    """ 
14881566    obj  =  slice (start , stop , step )
14891567    f  =  lambda  x : x [obj ]
1490-     return  _na_map (f , arr )
1568+     return  _na_map (f , arr ,  dtype = str )
14911569
14921570
14931571def  str_slice_replace (arr , start = None , stop = None , repl = None ):
@@ -1578,7 +1656,7 @@ def f(x):
15781656            y  +=  x [local_stop :]
15791657        return  y 
15801658
1581-     return  _na_map (f , arr )
1659+     return  _na_map (f , arr ,  dtype = str )
15821660
15831661
15841662def  str_strip (arr , to_strip = None , side = "both" ):
@@ -1603,7 +1681,7 @@ def str_strip(arr, to_strip=None, side="both"):
16031681        f  =  lambda  x : x .rstrip (to_strip )
16041682    else :  # pragma: no cover 
16051683        raise  ValueError ("Invalid side" )
1606-     return  _na_map (f , arr )
1684+     return  _na_map (f , arr ,  dtype = str )
16071685
16081686
16091687def  str_wrap (arr , width , ** kwargs ):
@@ -1667,7 +1745,7 @@ def str_wrap(arr, width, **kwargs):
16671745
16681746    tw  =  textwrap .TextWrapper (** kwargs )
16691747
1670-     return  _na_map (lambda  s : "\n " .join (tw .wrap (s )), arr )
1748+     return  _na_map (lambda  s : "\n " .join (tw .wrap (s )), arr ,  dtype = str )
16711749
16721750
16731751def  str_translate (arr , table ):
@@ -1687,7 +1765,7 @@ def str_translate(arr, table):
16871765    ------- 
16881766    Series or Index 
16891767    """ 
1690-     return  _na_map (lambda  x : x .translate (table ), arr )
1768+     return  _na_map (lambda  x : x .translate (table ), arr ,  dtype = str )
16911769
16921770
16931771def  str_get (arr , i ):
@@ -3025,7 +3103,7 @@ def normalize(self, form):
30253103        import  unicodedata 
30263104
30273105        f  =  lambda  x : unicodedata .normalize (form , x )
3028-         result  =  _na_map (f , self ._parent )
3106+         result  =  _na_map (f , self ._parent ,  dtype = str )
30293107        return  self ._wrap_result (result )
30303108
30313109    _shared_docs [
@@ -3223,31 +3301,37 @@ def rindex(self, sub, start=0, end=None):
32233301        lambda  x : x .lower (),
32243302        name = "lower" ,
32253303        docstring = _shared_docs ["casemethods" ] %  _doc_args ["lower" ],
3304+         dtype = str ,
32263305    )
32273306    upper  =  _noarg_wrapper (
32283307        lambda  x : x .upper (),
32293308        name = "upper" ,
32303309        docstring = _shared_docs ["casemethods" ] %  _doc_args ["upper" ],
3310+         dtype = str ,
32313311    )
32323312    title  =  _noarg_wrapper (
32333313        lambda  x : x .title (),
32343314        name = "title" ,
32353315        docstring = _shared_docs ["casemethods" ] %  _doc_args ["title" ],
3316+         dtype = str ,
32363317    )
32373318    capitalize  =  _noarg_wrapper (
32383319        lambda  x : x .capitalize (),
32393320        name = "capitalize" ,
32403321        docstring = _shared_docs ["casemethods" ] %  _doc_args ["capitalize" ],
3322+         dtype = str ,
32413323    )
32423324    swapcase  =  _noarg_wrapper (
32433325        lambda  x : x .swapcase (),
32443326        name = "swapcase" ,
32453327        docstring = _shared_docs ["casemethods" ] %  _doc_args ["swapcase" ],
3328+         dtype = str ,
32463329    )
32473330    casefold  =  _noarg_wrapper (
32483331        lambda  x : x .casefold (),
32493332        name = "casefold" ,
32503333        docstring = _shared_docs ["casemethods" ] %  _doc_args ["casefold" ],
3334+         dtype = str ,
32513335    )
32523336
32533337    _shared_docs [
0 commit comments