99
1010import numpy as np
1111
12- from pandas ._config import get_option
12+ from pandas ._config import (
13+ get_option ,
14+ using_string_dtype ,
15+ )
1316
1417from pandas ._libs import (
1518 lib ,
@@ -81,8 +84,10 @@ class StringDtype(StorageExtensionDtype):
8184
8285 Parameters
8386 ----------
84- storage : {"python", "pyarrow", "pyarrow_numpy" }, optional
87+ storage : {"python", "pyarrow"}, optional
8588 If not given, the value of ``pd.options.mode.string_storage``.
89+ na_value : {np.nan, pd.NA}, default pd.NA
90+ Whether the dtype follows NaN or NA missing value semantics.
8691
8792 Attributes
8893 ----------
@@ -113,30 +118,67 @@ class StringDtype(StorageExtensionDtype):
113118 # follows NumPy semantics, which uses nan.
@property
def na_value(self) -> libmissing.NAType | float:  # type: ignore[override]
    """
    Missing-value sentinel associated with this dtype instance.

    Returns whatever was stored on ``self._na_value``.
    """
    sentinel = self._na_value
    return sentinel
120122
121- _metadata = ("storage" ,)
123+ _metadata = ("storage" , "_na_value" ) # type: ignore[assignment]
122124
def __init__(
    self,
    storage: str | None = None,
    na_value: libmissing.NAType | float = libmissing.NA,
) -> None:
    """
    Set up the dtype's storage backend and missing-value sentinel.

    Parameters
    ----------
    storage : {"python", "pyarrow"}, optional
        When omitted, "pyarrow" is used if ``using_string_dtype()`` is true,
        otherwise the value of the ``mode.string_storage`` option. The legacy
        spelling "pyarrow_numpy" is still accepted and is translated to
        ``storage="pyarrow"`` together with ``na_value=np.nan``.
    na_value : {np.nan, pd.NA}, default pd.NA
        Missing-value sentinel. Any float NaN is normalized to ``np.nan``.

    Raises
    ------
    ValueError
        If ``storage`` or ``na_value`` is not one of the accepted values.
    ImportError
        If pyarrow storage is requested while ``pa_version_under10p1`` is set.
    """
    # infer defaults
    if storage is None:
        if using_string_dtype():
            storage = "pyarrow"
        else:
            storage = get_option("mode.string_storage")

    if storage == "pyarrow_numpy":
        # TODO raise a deprecation warning
        # legacy alias: pyarrow storage with NaN as the missing value
        storage = "pyarrow"
        na_value = np.nan

    # validate options
    if storage not in {"python", "pyarrow"}:
        raise ValueError(
            f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
        )
    if storage == "pyarrow" and pa_version_under10p1:
        raise ImportError(
            "pyarrow>=10.0.1 is required for PyArrow backed StringArray."
        )

    if isinstance(na_value, float) and np.isnan(na_value):
        # when passed a NaN value, always set to np.nan to ensure we use
        # a consistent NaN value (and we can use `dtype.na_value is np.nan`)
        na_value = np.nan
    elif na_value is not libmissing.NA:
        # f-string so the offending value is actually shown in the message
        raise ValueError(f"'na_value' must be np.nan or pd.NA, got {na_value}")

    self.storage = storage
    self._na_value = na_value
161+
def __eq__(self, other: object) -> bool:
    """
    Compare this dtype with another dtype or a dtype string.

    A string equal to ``self.name`` matches immediately. Any other string is
    parsed with ``construct_from_string``; a ``TypeError`` from parsing means
    "not equal". Two dtype instances are equal when their storage matches and
    their ``na_value`` is the very same object.
    """
    # The base-class __eq__ is overridden because na_value (NA or NaN)
    # cannot be checked with normal `==`; identity is used instead.
    candidate = other
    if isinstance(candidate, str):
        if candidate == self.name:
            return True
        try:
            candidate = self.construct_from_string(candidate)
        except TypeError:
            return False

    if not isinstance(candidate, type(self)):
        return False
    if self.storage != candidate.storage:
        return False
    return self.na_value is candidate.na_value
175+
def __hash__(self) -> int:
    """
    Return the hash computed by the parent class.

    Defining ``__eq__`` on a class would otherwise leave it without an
    inherited ``__hash__``, so the parent implementation is restored here.
    """
    parent_hash = super().__hash__()
    return parent_hash
179+
def __reduce__(self):
    """
    Describe how to rebuild this dtype (pickle protocol).

    Returns the ``StringDtype`` constructor together with the positional
    arguments ``(storage, na_value)`` that recreate this instance.
    """
    ctor_args = (self.storage, self.na_value)
    return StringDtype, ctor_args
140182
141183 @property
142184 def type (self ) -> type [str ]:
@@ -181,6 +223,7 @@ def construct_from_string(cls, string) -> Self:
181223 elif string == "string[pyarrow]" :
182224 return cls (storage = "pyarrow" )
183225 elif string == "string[pyarrow_numpy]" :
226+ # TODO deprecate
184227 return cls (storage = "pyarrow_numpy" )
185228 else :
186229 raise TypeError (f"Cannot construct a '{ cls .__name__ } ' from '{ string } '" )
@@ -205,7 +248,7 @@ def construct_array_type( # type: ignore[override]
205248
206249 if self .storage == "python" :
207250 return StringArray
208- elif self .storage == "pyarrow" :
251+ elif self .storage == "pyarrow" and self . _na_value is libmissing . NA :
209252 return ArrowStringArray
210253 else :
211254 return ArrowStringArrayNumpySemantics
@@ -217,13 +260,17 @@ def __from_arrow__(
217260 Construct StringArray from pyarrow Array/ChunkedArray.
218261 """
219262 if self .storage == "pyarrow" :
220- from pandas .core .arrays .string_arrow import ArrowStringArray
263+ if self ._na_value is libmissing .NA :
264+ from pandas .core .arrays .string_arrow import ArrowStringArray
265+
266+ return ArrowStringArray (array )
267+ else :
268+ from pandas .core .arrays .string_arrow import (
269+ ArrowStringArrayNumpySemantics ,
270+ )
221271
222- return ArrowStringArray (array )
223- elif self .storage == "pyarrow_numpy" :
224- from pandas .core .arrays .string_arrow import ArrowStringArrayNumpySemantics
272+ return ArrowStringArrayNumpySemantics (array )
225273
226- return ArrowStringArrayNumpySemantics (array )
227274 else :
228275 import pyarrow
229276
0 commit comments