88 TYPE_CHECKING ,
99 Any ,
1010 Dict ,
11+ List ,
1112 Literal ,
1213 Optional ,
1314 Protocol ,
2324import numpy as np
2425
2526from ._typing import CNumericPtr , DataType , NumpyDType , NumpyOrCupy
26- from .compat import import_cupy , lazy_isinstance
27+ from .compat import import_cupy , import_pyarrow , lazy_isinstance
2728
2829if TYPE_CHECKING :
2930 import pandas as pd
@@ -69,7 +70,11 @@ def shape(self) -> Tuple[int, int]:
6970
7071def array_hasobject (data : DataType ) -> bool :
7172 """Whether the numpy array has object dtype."""
72- return hasattr (data .dtype , "hasobject" ) and data .dtype .hasobject
73+ return (
74+ hasattr (data , "dtype" )
75+ and hasattr (data .dtype , "hasobject" )
76+ and data .dtype .hasobject
77+ )
7378
7479
7580def cuda_array_interface_dict (data : _CudaArrayLikeArg ) -> ArrayInf :
@@ -180,7 +185,7 @@ def is_arrow_dict(data: Any) -> TypeGuard["pa.DictionaryArray"]:
180185 return lazy_isinstance (data , "pyarrow.lib" , "DictionaryArray" )
181186
182187
183- class PdCatAccessor (Protocol ):
188+ class DfCatAccessor (Protocol ):
184189 """Protocol for pandas cat accessor."""
185190
186191 @property
@@ -202,7 +207,7 @@ def to_arrow( # pylint: disable=missing-function-docstring
202207 def __cuda_array_interface__ (self ) -> ArrayInf : ...
203208
204209
205- def _is_pd_cat (data : Any ) -> TypeGuard [PdCatAccessor ]:
210+ def _is_df_cat (data : Any ) -> TypeGuard [DfCatAccessor ]:
206211 # Test pd.Series.cat, not pd.Series
207212 return hasattr (data , "categories" ) and hasattr (data , "codes" )
208213
@@ -234,6 +239,67 @@ def npstr_to_arrow_strarr(strarr: np.ndarray) -> Tuple[np.ndarray, str]:
234239 return offsets .astype (np .int32 ), values
235240
236241
242+ def _arrow_cat_inf ( # pylint: disable=too-many-locals
243+ cats : "pa.StringArray" ,
244+ codes : Union [_ArrayLikeArg , _CudaArrayLikeArg , "pa.IntegerArray" ],
245+ ) -> Tuple [StringArray , ArrayInf , Tuple ]:
246+ if not TYPE_CHECKING :
247+ pa = import_pyarrow ()
248+
249+ # FIXME(jiamingy): Account for offset, need to find an implementation that returns
250+ # offset > 0
251+ assert cats .offset == 0
252+ buffers : List [pa .Buffer ] = cats .buffers ()
253+ mask , offset , data = buffers
254+ assert offset .is_cpu
255+
256+ off_len = len (cats ) + 1
257+ if offset .size != off_len * (np .iinfo (np .int32 ).bits / 8 ):
258+ raise TypeError ("Arrow dictionary type offsets is required to be 32 bit." )
259+
260+ joffset : ArrayInf = {
261+ "data" : (offset .address , True ),
262+ "typestr" : "<i4" ,
263+ "version" : 3 ,
264+ "strides" : None ,
265+ "shape" : (off_len ,),
266+ "mask" : None ,
267+ }
268+
269+ def make_buf_inf (buf : pa .Buffer , typestr : str ) -> ArrayInf :
270+ return {
271+ "data" : (buf .address , True ),
272+ "typestr" : typestr ,
273+ "version" : 3 ,
274+ "strides" : None ,
275+ "shape" : (buf .size ,),
276+ "mask" : None ,
277+ }
278+
279+ jdata = make_buf_inf (data , "<i1" )
280+ # Categories should not have missing values.
281+ assert mask is None
282+
283+ jnames : StringArray = {"offsets" : joffset , "values" : jdata }
284+
285+ def make_array_inf (
286+ array : Any ,
287+ ) -> Tuple [ArrayInf , Optional [Tuple [pa .Buffer , pa .Buffer ]]]:
288+ """Helper for handling categorical codes."""
289+ # Handle cuDF data
290+ if hasattr (array , "__cuda_array_interface__" ):
291+ inf = cuda_array_interface_dict (array )
292+ return inf , None
293+
294+ # Other types (like arrow itself) are not yet supported.
295+ raise TypeError ("Invalid input type." )
296+
297+ cats_tmp = (mask , offset , data )
298+ jcodes , codes_tmp = make_array_inf (codes )
299+
300+ return jnames , jcodes , (cats_tmp , codes_tmp )
301+
302+
237303def _ensure_np_dtype (
238304 data : DataType , dtype : Optional [NumpyDType ]
239305) -> Tuple [np .ndarray , Optional [NumpyDType ]]:
@@ -252,7 +318,7 @@ def array_interface_dict(data: np.ndarray) -> ArrayInf: ...
252318
253319@overload
254320def array_interface_dict (
255- data : PdCatAccessor ,
321+ data : DfCatAccessor ,
256322) -> Tuple [StringArray , ArrayInf , Tuple ]: ...
257323
258324
@@ -263,11 +329,11 @@ def array_interface_dict(
263329
264330
265331def array_interface_dict ( # pylint: disable=too-many-locals
266- data : Union [np .ndarray , PdCatAccessor ],
332+ data : Union [np .ndarray , DfCatAccessor ],
267333) -> Union [ArrayInf , Tuple [StringArray , ArrayInf , Optional [Tuple ]]]:
268334 """Returns an array interface from the input."""
269335 # Handle categorical values
270- if _is_pd_cat (data ):
336+ if _is_df_cat (data ):
271337 cats = data .categories
272338 # pandas uses -1 to represent missing values for categorical features
273339 codes = data .codes .replace (- 1 , np .nan )
@@ -287,6 +353,7 @@ def array_interface_dict( # pylint: disable=too-many-locals
287353 name_offsets , _ = _ensure_np_dtype (name_offsets , np .int32 )
288354 joffsets = array_interface_dict (name_offsets )
289355 bvalues = name_values .encode ("utf-8" )
356+
290357 ptr = ctypes .c_void_p .from_buffer (ctypes .c_char_p (bvalues )).value
291358 assert ptr is not None
292359
@@ -335,3 +402,20 @@ def check_cudf_meta(data: _CudaArrayLikeArg, field: str) -> None:
335402 and data .__cuda_array_interface__ ["mask" ] is not None
336403 ):
337404 raise ValueError (f"Missing value is not allowed for: { field } " )
405+
406+
407+ def cudf_cat_inf (
408+ cats : DfCatAccessor , codes : "pd.Series"
409+ ) -> Tuple [Union [ArrayInf , StringArray ], ArrayInf , Tuple ]:
410+ """Obtain the cuda array interface for cuDF categories."""
411+ cp = import_cupy ()
412+ is_num_idx = cp .issubdtype (cats .dtype , cp .floating ) or cp .issubdtype (
413+ cats .dtype , cp .integer
414+ )
415+ if is_num_idx :
416+ cats_ainf = cats .__cuda_array_interface__
417+ codes_ainf = cuda_array_interface_dict (codes )
418+ return cats_ainf , codes_ainf , (cats , codes )
419+
420+ joffset , jdata , buf = _arrow_cat_inf (cats .to_arrow (), codes )
421+ return joffset , jdata , buf
0 commit comments