From 2df528d50d2353d0e54c3ea546cf500d19dbd555 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 5 Mar 2025 19:34:48 +0800 Subject: [PATCH 01/12] Store categories from cuDF. - Add glue code for cuDF through arrow. --- python-package/xgboost/_data_utils.py | 103 +++++++++++++- python-package/xgboost/_typing.py | 4 +- python-package/xgboost/core.py | 26 ++-- python-package/xgboost/data.py | 155 ++++++++++++---------- python-package/xgboost/testing/data.py | 8 ++ python-package/xgboost/testing/ordinal.py | 56 ++++++-- src/data/device_adapter.cu | 60 +++++++++ src/data/device_adapter.cuh | 72 +++++----- src/data/simple_dmatrix.cu | 13 +- tests/python-gpu/test_gpu_ordinal.py | 14 ++ tests/python/test_ordinal.py | 2 +- 11 files changed, 364 insertions(+), 149 deletions(-) create mode 100644 src/data/device_adapter.cu create mode 100644 tests/python-gpu/test_gpu_ordinal.py diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index b56aec341b30..85f2fe426520 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -8,6 +8,7 @@ TYPE_CHECKING, Any, Dict, + List, Literal, Optional, Protocol, @@ -23,7 +24,7 @@ import numpy as np from ._typing import CNumericPtr, DataType, NumpyDType, NumpyOrCupy -from .compat import import_cupy, lazy_isinstance +from .compat import import_cupy, import_pyarrow, lazy_isinstance if TYPE_CHECKING: import pandas as pd @@ -69,7 +70,11 @@ def shape(self) -> Tuple[int, int]: def array_hasobject(data: DataType) -> bool: """Whether the numpy array has object dtype.""" - return hasattr(data.dtype, "hasobject") and data.dtype.hasobject + return ( + hasattr(data, "dtype") + and hasattr(data.dtype, "hasobject") + and data.dtype.hasobject + ) def cuda_array_interface_dict(data: _CudaArrayLikeArg) -> ArrayInf: @@ -202,7 +207,7 @@ def to_arrow( # pylint: disable=missing-function-docstring def __cuda_array_interface__(self) -> ArrayInf: ... -def _is_pd_cat(data: Any) -> TypeGuard[PdCatAccessor]: +def _is_df_cat(data: Any) -> TypeGuard[PdCatAccessor]: # Test pd.Series.cat, not pd.Series return hasattr(data, "categories") and hasattr(data, "codes") @@ -234,6 +239,69 @@ def npstr_to_arrow_strarr(strarr: np.ndarray) -> Tuple[np.ndarray, str]: return offsets.astype(np.int32), values +def _arrow_cat_inf( # pylint: disable=too-many-locals + cats: "pa.StringArray", + codes: Union[_ArrayLikeArg, _CudaArrayLikeArg, "pa.IntegerArray"], +) -> Tuple[StringArray, ArrayInf, Tuple]: + if not TYPE_CHECKING: + pa = import_pyarrow() + + # FIXME(jiamingy): Account for offset, need to find an implementation that returns + # offset > 0 + assert cats.offset == 0 + buffers: List[pa.Buffer] = cats.buffers() + mask, offset, data = buffers + assert offset.is_cpu + + off_len = len(cats) + 1 + if offset.size != off_len * (np.iinfo(np.int32).bits / 8): + raise TypeError("Arrow dictionary type offsets is required to be 32 bit.") + + joffset: ArrayInf = { + "data": (offset.address, True), + "typestr": " ArrayInf: + return { + "data": (buf.address, True), + "typestr": typestr, + "version": 3, + "strides": None, + "shape": (buf.size,), + "mask": None, + } + + jdata = make_buf_inf(data, " Tuple[ArrayInf, Optional[Tuple[pa.Buffer, pa.Buffer]]]: + """Helper for handling categorical codes.""" + # Handle cuDF data + if hasattr(array, "__cuda_array_interface__"): + inf = array.__cuda_array_interface__ + if "mask" in inf: + inf["mask"] = inf["mask"].__cuda_array_interface__ + return inf, None + + # Other types (like arrow itself) are not yet supported. + raise TypeError("Invalid input type.") + + cats_tmp = (mask, offset, data) + jcodes, codes_tmp = make_array_inf(codes) + + return jnames, jcodes, (cats_tmp, codes_tmp) + + def _ensure_np_dtype( data: DataType, dtype: Optional[NumpyDType] ) -> Tuple[np.ndarray, Optional[NumpyDType]]: @@ -267,7 +335,12 @@ def array_interface_dict( # pylint: disable=too-many-locals ) -> Union[ArrayInf, Tuple[StringArray, ArrayInf, Optional[Tuple]]]: """Returns an array interface from the input.""" # Handle categorical values - if _is_pd_cat(data): + if is_arrow_dict(data): + cats = data.dictionary + codes = data.indices + jnames, jcodes, buf = _arrow_cat_inf(cats, codes) + return jnames, jcodes, buf + if _is_df_cat(data): cats = data.categories # pandas uses -1 to represent missing values for categorical features codes = data.codes.replace(-1, np.nan) @@ -287,6 +360,7 @@ def array_interface_dict( # pylint: disable=too-many-locals name_offsets, _ = _ensure_np_dtype(name_offsets, np.int32) joffsets = array_interface_dict(name_offsets) bvalues = name_values.encode("utf-8") + ptr = ctypes.c_void_p.from_buffer(ctypes.c_char_p(bvalues)).value assert ptr is not None @@ -298,7 +372,7 @@ def array_interface_dict( # pylint: disable=too-many-locals "version": 3, "mask": None, } - jnames: StringArray = {"offsets": joffsets, "values": jvalues} + jnames = {"offsets": joffsets, "values": jvalues} code_values = codes.values jcodes = array_interface_dict(code_values) @@ -335,3 +409,22 @@ def check_cudf_meta(data: _CudaArrayLikeArg, field: str) -> None: and data.__cuda_array_interface__["mask"] is not None ): raise ValueError(f"Missing value is not allowed for: {field}") + + +def cudf_cat_inf( + cats: PdCatAccessor, codes: "pd.Series" +) -> Tuple[Union[ArrayInf, StringArray], ArrayInf, Tuple]: + """Obtain the cuda array interface for cuDF categories.""" + cp = import_cupy() + is_num_idx = cp.issubdtype(cats.dtype, cp.floating) or cp.issubdtype( + cats.dtype, cp.integer + ) + if is_num_idx: + cats_ainf = cats.__cuda_array_interface__ + codes_ainf = codes.__cuda_array_interface__ + if "mask" in codes_ainf: + codes_ainf["mask"] = codes_ainf["mask"].__cuda_array_interface__ + return cats_ainf, codes_ainf, (cats, codes) + + joffset, jdata, buf = _arrow_cat_inf(cats.to_arrow(), codes) + return joffset, jdata, buf diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py index bf8186417913..97536ccfb2cb 100644 --- a/python-package/xgboost/_typing.py +++ b/python-package/xgboost/_typing.py @@ -109,9 +109,7 @@ # The second arg is actually Optional[List[cudf.Series]], skipped for easier type check. # The cudf Series is the obtained cat codes, preserved in the `DataIter` to prevent it # being freed. -TransformedData = Tuple[ - Any, Optional[List], Optional[FeatureNames], Optional[FeatureTypes] -] +TransformedData = Tuple[Any, Optional[FeatureNames], Optional[FeatureTypes]] # template parameter _T = TypeVar("_T") diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 98a00b664bb2..ee0f1c905276 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -626,17 +626,17 @@ def input_data( and ref is not None and ref is self._data_ref ): - new, cat_codes, feature_names, feature_types = self._temporary_data + new, feature_names, feature_types = self._temporary_data else: - new, cat_codes, feature_names, feature_types = _proxy_transform( + new, feature_names, feature_types = _proxy_transform( data, feature_names, feature_types, self._enable_categorical, ) # Stage the data, meta info are copied inside C++ MetaInfo. - self._temporary_data = (new, cat_codes, feature_names, feature_types) - dispatch_proxy_set_data(self.proxy, new, cat_codes) + self._temporary_data = (new, feature_names, feature_types) + dispatch_proxy_set_data(self.proxy, new) self.proxy.set_info( feature_names=feature_names, feature_types=feature_types, @@ -1525,12 +1525,11 @@ def _ref_data_from_cuda_interface(self, data: DataType) -> None: arrinf = cuda_array_interface(data) _check_call(_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, arrinf)) - def _ref_data_from_cuda_columnar(self, data: DataType, cat_codes: list) -> None: + def _ref_data_from_cuda_columnar(self, data: TransformedDf) -> None: """Reference data from CUDA columnar format.""" - from .data import _cudf_array_interfaces - - interfaces_str = _cudf_array_interfaces(data, cat_codes) - _check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str)) + _check_call( + _LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, data.array_interface()) + ) def _ref_data_from_array(self, data: np.ndarray) -> None: """Reference data from numpy array.""" @@ -2822,18 +2821,15 @@ def inplace_predict( ) return _prediction_output(shape, dims, preds, True) if _is_cudf_df(data): - from .data import _cudf_array_interfaces, _transform_cudf_df + from .data import _transform_cudf_df - data, cat_codes, fns, _ = _transform_cudf_df( - data, None, None, enable_categorical - ) - interfaces_str = _cudf_array_interfaces(data, cat_codes) + df, fns, _ = _transform_cudf_df(data, None, None, enable_categorical) if validate_features: self._validate_features(fns) _check_call( _LIB.XGBoosterPredictFromCudaColumnar( self.handle, - interfaces_str, + df.array_interface(), args, p_handle, ctypes.byref(shape), diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 039fa35e88fa..ea95df7556d4 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -18,7 +18,6 @@ Type, TypeGuard, Union, - cast, ) import numpy as np @@ -26,14 +25,17 @@ from ._data_utils import ( ArrayInf, PdCatAccessor, + StringArray, TransformedDf, _ensure_np_dtype, - _is_pd_cat, + _is_df_cat, array_hasobject, array_interface, array_interface_dict, check_cudf_meta, cuda_array_interface, + cuda_array_interface_dict, + cudf_cat_inf, is_arrow_dict, make_array_interface, ) @@ -64,7 +66,6 @@ _check_call, _ProxyDMatrix, c_str, - from_pystr_to_cstr, make_jcargs, ) @@ -638,9 +639,9 @@ def array_interface(self) -> bytes: def shape(self) -> Tuple[int, int]: """Return shape of the transformed DataFrame.""" if is_arrow_dict(self.columns[0]): - # When input is arrow. (cuDF) + # When input is arrow. n_samples = len(self.columns[0].indices) - elif _is_pd_cat(self.columns[0]): + elif _is_df_cat(self.columns[0]): # When input is pandas. n_samples = self.columns[0].codes.shape[0] else: @@ -1056,37 +1057,62 @@ def is_categorical_dtype(dtype: Any) -> bool: return is_categorical_dtype -def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes: - """Extract CuDF __cuda_array_interface__. This is special as it returns a new list - of data and a list of array interfaces. The data is list of categorical codes that - caller can safely ignore, but have to keep their reference alive until usage of - array interface is finished. +@functools.cache +def _lazy_load_cudf_is_bool() -> Callable[[Any], bool]: + from cudf.api.types import is_bool_dtype - """ - is_categorical_dtype = _lazy_load_cudf_is_cat() - interfaces = [] + return is_bool_dtype - def append(interface: dict) -> None: - if "mask" in interface: - interface["mask"] = interface["mask"].__cuda_array_interface__ - interfaces.append(interface) - if _is_cudf_ser(data): - if is_categorical_dtype(data.dtype): - interface = cat_codes[0].__cuda_array_interface__ - else: - interface = data.__cuda_array_interface__ - append(interface) - else: - for i, col in enumerate(data): - if is_categorical_dtype(data[col].dtype): - codes = cat_codes[i] - interface = codes.__cuda_array_interface__ +class CudfTransformed(TransformedDf): + """A storage class for transformed cuDF dataframe.""" + + def __init__(self, columns: List[Union["PdSeries", PdCatAccessor]]) -> None: + self.columns = columns + # Buffers for temporary data that cannot be freed until the data is consumed by + # the DMatrix or the booster. + self.temporary_buffers: List[Tuple] = [] + + aitfs: List[ + Union[ + ArrayInf, # numeric column + Tuple[ # categorical column + Union[ArrayInf, StringArray], # string index, numeric index + ArrayInf, # codes + ], + ] + ] = [] + + def push_series(ser: Any) -> None: + if _is_df_cat(ser): + cats, codes = ser.categories, ser.codes + cats_ainf: Union[StringArray, ArrayInf] # string or numeric index + cats_ainf, codes_ainf, buf = cudf_cat_inf(cats, codes) + self.temporary_buffers.append(buf) + aitfs.append((cats_ainf, codes_ainf)) else: - interface = data[col].__cuda_array_interface__ - append(interface) - interfaces_str = from_pystr_to_cstr(json.dumps(interfaces)) - return interfaces_str + # numeric column + ainf = cuda_array_interface_dict(ser) + aitfs.append(ainf) + + for col in self.columns: + push_series(col) + + self.aitfs = aitfs + + def array_interface(self) -> bytes: + """Return a byte string for JSON encoded array interface.""" + sarrays = bytes(json.dumps(self.aitfs), "utf-8") + return sarrays + + @property + def shape(self) -> Tuple[int, int]: + """Return shape of the transformed DataFrame.""" + if _is_df_cat(self.columns[0]): + n_samples = self.columns[0].codes.shape[0] + else: + n_samples = self.columns[0].shape[0] # type: ignore + return n_samples, len(self.columns) def _transform_cudf_df( @@ -1094,26 +1120,23 @@ def _transform_cudf_df( feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], enable_categorical: bool, -) -> Tuple[ctypes.c_void_p, list, Optional[FeatureNames], Optional[FeatureTypes]]: - - try: - from cudf.api.types import is_bool_dtype - except ImportError: - from pandas.api.types import is_bool_dtype +) -> Tuple[ + CudfTransformed, + Optional[FeatureNames], + Optional[FeatureTypes], +]: + is_bool_dtype = _lazy_load_cudf_is_bool() is_categorical_dtype = _lazy_load_cudf_is_cat() # Work around https://github.com/dmlc/xgboost/issues/10181 if _is_cudf_ser(data): if is_bool_dtype(data.dtype): data = data.astype(np.uint8) + dtypes = [data.dtype] else: data = data.astype( {col: np.uint8 for col in data.select_dtypes(include="bool")} ) - - if _is_cudf_ser(data): - dtypes = [data.dtype] - else: dtypes = data.dtypes if not all( @@ -1142,24 +1165,26 @@ def _transform_cudf_df( feature_types.append(_pandas_dtype_mapper[dtype.name]) # handle categorical data - cat_codes = [] + result = [] if _is_cudf_ser(data): # unlike pandas, cuDF uses NA for missing data. if is_categorical_dtype(data.dtype) and enable_categorical: - codes = data.cat.codes - cat_codes.append(codes) + result.append(data.cat) + elif enable_categorical: + raise ValueError(_ENABLE_CAT_ERR) + else: + result.append(data) else: for col in data: dtype = data[col].dtype if is_categorical_dtype(dtype) and enable_categorical: - codes = data[col].cat.codes - cat_codes.append(codes) + result.append(data[col].cat) elif is_categorical_dtype(dtype): raise ValueError(_ENABLE_CAT_ERR) else: - cat_codes.append([]) + result.append(data[col]) - return data, cat_codes, feature_names, feature_types + return CudfTransformed(result), feature_names, feature_types def _from_cudf_df( @@ -1171,14 +1196,13 @@ def _from_cudf_df( feature_types: Optional[FeatureTypes], enable_categorical: bool, ) -> DispatchedDataBackendReturnType: - data, cat_codes, feature_names, feature_types = _transform_cudf_df( + df, feature_names, feature_types = _transform_cudf_df( data, feature_names, feature_types, enable_categorical ) - interfaces_str = _cudf_array_interfaces(data, cat_codes) handle = ctypes.c_void_p() _check_call( _LIB.XGDMatrixCreateFromCudaColumnar( - interfaces_str, + df.array_interface(), make_jcargs(nthread=nthread, missing=missing), ctypes.byref(handle), ) @@ -1694,28 +1718,28 @@ def _proxy_transform( ) if _is_cupy_alike(data): data = _transform_cupy_array(data) - return data, None, feature_names, feature_types + return data, feature_names, feature_types if _is_dlpack(data): - return _transform_dlpack(data), None, feature_names, feature_types + return _transform_dlpack(data), feature_names, feature_types if _is_list(data) or _is_tuple(data): data = np.array(data) if _is_np_array_like(data): data, _ = _ensure_np_dtype(data, data.dtype) - return data, None, feature_names, feature_types + return data, feature_names, feature_types if is_scipy_csr(data): data = transform_scipy_sparse(data, True) - return data, None, feature_names, feature_types + return data, feature_names, feature_types if is_scipy_csc(data): data = transform_scipy_sparse(data.tocsr(), True) - return data, None, feature_names, feature_types + return data, feature_names, feature_types if is_scipy_coo(data): data = transform_scipy_sparse(data.tocsr(), True) - return data, None, feature_names, feature_types + return data, feature_names, feature_types if _is_polars(data): df_pl, feature_names, feature_types = _transform_polars_df( data, enable_categorical, feature_names, feature_types ) - return df_pl, None, feature_names, feature_types + return df_pl, feature_names, feature_types if _is_pandas_series(data): import pandas as pd @@ -1724,12 +1748,12 @@ def _proxy_transform( df_pa, feature_names, feature_types = _transform_arrow_table( data, enable_categorical, feature_names, feature_types ) - return df_pa, None, feature_names, feature_types + return df_pa, feature_names, feature_types if _is_pandas_df(data): df, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) - return df, None, feature_names, feature_types + return df, feature_names, feature_types raise TypeError("Value type is not supported for data iterator:" + str(type(data))) @@ -1741,7 +1765,6 @@ def is_on_cuda(data: Any) -> bool: def dispatch_proxy_set_data( proxy: _ProxyDMatrix, data: DataType, - cat_codes: Optional[list], ) -> None: """Dispatch for QuantileDMatrix.""" if ( @@ -1751,13 +1774,9 @@ def dispatch_proxy_set_data( ): _check_data_shape(data) - if _is_cudf_df(data): - # pylint: disable=W0212 - proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes)) - return - if _is_cudf_ser(data): + if isinstance(data, CudfTransformed): # pylint: disable=W0212 - proxy._ref_data_from_cuda_columnar(data, cast(List, cat_codes)) + proxy._ref_data_from_cuda_columnar(data) return if _is_cupy_alike(data): proxy._ref_data_from_cuda_interface(data) # pylint: disable=W0212 diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 4ad5915aad88..de6725ae58de 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -965,6 +965,7 @@ def make_categorical( shuffle: bool = False, random_state: int = 1994, cat_dtype: np.typing.DTypeLike = np.int64, + device: str = "cpu", ) -> Tuple[ArrayLike, np.ndarray]: """Generate categorical features for test. @@ -1034,4 +1035,11 @@ def make_categorical( rng.shuffle(columns) df = df[columns] + if device != "cpu": + assert device in ["cuda", "gpu"] + import cudf + import cupy + + df = cudf.from_pandas(df) + label = cupy.array(label) return df, label diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py index dafd06abec13..f4925bdb1313 100644 --- a/python-package/xgboost/testing/ordinal.py +++ b/python-package/xgboost/testing/ordinal.py @@ -3,7 +3,7 @@ import os import tempfile -from typing import Any, Tuple, Type +from typing import Any, Literal, Tuple, Type import numpy as np @@ -36,7 +36,7 @@ def assert_allclose(device: str, a: Any, b: Any) -> None: cp.testing.assert_allclose(a, b) -def run_cat_container(device: str) -> None: +def run_cat_container(device: Literal["cpu", "cuda"]) -> None: """Basic tests for the container class used by the DMatrix.""" Df, _ = get_df_impl(device) # Basic test with a single feature @@ -75,20 +75,30 @@ def run_cat_container(device: str) -> None: Xy = DMatrix(df, enable_categorical=True) -def run_cat_container_mixed() -> None: +def run_cat_container_mixed(device: Literal["cpu", "cuda"]) -> None: """Run checks with mixed types.""" import pandas as pd + from ..data import _lazy_load_cudf_is_cat + + is_cudf_cat = _lazy_load_cudf_is_cat() + n_samples = int(2**10) + def check(Xy: DMatrix, X: pd.DataFrame) -> None: cats = Xy.get_categories() assert cats is not None for fname in X.columns: - if is_pd_cat_dtype(X[fname].dtype): + if is_pd_cat_dtype(X[fname].dtype) or is_cudf_cat(X[fname].dtype): aw_list = sorted(cats[fname].to_pylist()) - pd_list: list = X[fname].unique().tolist() - if np.nan in pd_list: + if is_cudf_cat(X[fname].dtype): + pd_list: list = X[fname].unique().to_arrow().to_pylist() + else: + pd_list = X[fname].unique().tolist() + if np.nan in pd_list: # pandas pd_list.remove(np.nan) + if None in pd_list: # cudf + pd_list.remove(None) pd_list = sorted(pd_list) assert aw_list == pd_list else: @@ -110,35 +120,57 @@ def check(Xy: DMatrix, X: pd.DataFrame) -> None: assert v_0.to_pylist() == v_1.to_pylist() # full str type - X, y = make_categorical(256, 16, 7, onehot=False, cat_dtype=np.str_) + X, y = make_categorical( + n_samples, 16, 7, onehot=False, cat_dtype=np.str_, device=device + ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) # str type, mixed with numerical features - X, y = make_categorical(256, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.str_) + X, y = make_categorical( + n_samples, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.str_, device=device + ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) # str type, mixed with numerical features and missing values X, y = make_categorical( - 256, 16, 7, onehot=False, cat_ratio=0.5, sparsity=0.5, cat_dtype=np.str_ + n_samples, + 16, + 7, + onehot=False, + cat_ratio=0.5, + sparsity=0.5, + cat_dtype=np.str_, + device=device, ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) # int type - X, y = make_categorical(256, 16, 7, onehot=False, cat_dtype=np.int64) + X, y = make_categorical( + n_samples, 16, 7, onehot=False, cat_dtype=np.int64, device=device + ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) # int type, mixed with numerical features - X, y = make_categorical(256, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.int64) + X, y = make_categorical( + n_samples, 16, 7, onehot=False, cat_ratio=0.5, cat_dtype=np.int64, device=device + ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) # int type, mixed with numerical features and missing values X, y = make_categorical( - 256, 16, 7, onehot=False, cat_ratio=0.5, sparsity=0.5, cat_dtype=np.int64 + n_samples, + 16, + 7, + onehot=False, + cat_ratio=0.5, + sparsity=0.5, + cat_dtype=np.int64, + device=device, ) Xy = DMatrix(X, y, enable_categorical=True) check(Xy, X) diff --git a/src/data/device_adapter.cu b/src/data/device_adapter.cu new file mode 100644 index 000000000000..0d37e2095959 --- /dev/null +++ b/src/data/device_adapter.cu @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2025, XGBoost Contributors + */ +#include "../common/cuda_rt_utils.h" +#include "device_adapter.cuh" + +namespace xgboost::data { +CudfAdapter::CudfAdapter(StringView cuda_arrinf) { + Json interfaces = Json::Load(cuda_arrinf); + std::vector const& jcolumns = get(interfaces); + std::size_t n_columns = jcolumns.size(); + CHECK_GT(n_columns, 0) << "The number of columns must not equal to 0."; + + std::vector> columns; + std::vector cat_segments{0}; + std::int32_t device = -1; + for (auto const& jcol : jcolumns) { + std::int32_t n_cats{0}; + if (IsA(jcol)) { + // This is a dictionary type (categorical values). + auto const& first = get(jcol[0]); + if (first.find("offsets") == first.cend()) { + // numeric index + n_cats = + GetArrowNumericIndex(DeviceOrd::CUDA(0), jcol, &cats_, &columns, &n_bytes_, &num_rows_); + } else { + // string index + n_cats = GetArrowDictionary(jcol, &cats_, &columns, &n_bytes_, &num_rows_); + } + } else { + // Numeric values + auto col = ArrayInterface<1>(get(jcol)); + columns.push_back(col); + this->cats_.emplace_back(); + this->num_rows_ = std::max(num_rows_, col.Shape<0>()); + CHECK_EQ(num_rows_, col.Shape<0>()) << "All columns should have same number of rows."; + n_bytes_ += col.ElementSize() * col.Shape<0>(); + } + cat_segments.emplace_back(n_cats); + if (device == -1) { + device = dh::CudaGetPointerDevice(columns.back().data); + } + CHECK_EQ(device, dh::CudaGetPointerDevice(columns.back().data)) + << "All columns should use the same device."; + } + // Categories + std::partial_sum(cat_segments.cbegin(), cat_segments.cend(), cat_segments.begin()); + this->n_total_cats_ = cat_segments.back(); + this->cat_segments_ = std::move(cat_segments); + this->d_cats_ = this->cats_; // thrust copy + + CHECK(!columns.empty()); + device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(columns.front().data)); + CHECK(device_.IsCUDA()); + curt::SetDevice(device_.ordinal); + + this->columns_ = columns; + batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_); +} +} // namespace xgboost::data diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index cad3cffbc58a..6203435b8c95 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -1,20 +1,24 @@ /** - * Copyright 2019-2024, XGBoost Contributors - * \file device_adapter.cuh + * Copyright 2019-2025, XGBoost Contributors + * @file device_adapter.cuh */ #ifndef XGBOOST_DATA_DEVICE_ADAPTER_H_ #define XGBOOST_DATA_DEVICE_ADAPTER_H_ + +#include // for maximum #include // for make_counting_iterator #include // for none_of -#include // for size_t -#include -#include +#include // for size_t +#include // for variant +#include // for numeric_limits +#include // for string #include "../common/cuda_context.cuh" #include "../common/device_helpers.cuh" #include "adapter.h" #include "array_interface.h" +#include "xgboost/string_view.h" // for StringView namespace xgboost::data { class CudfAdapterBatch : public detail::NoMetaInfo { @@ -23,8 +27,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { public: CudfAdapterBatch() = default; CudfAdapterBatch(common::Span> columns, size_t num_rows) - : columns_(columns), - num_rows_(num_rows) {} + : columns_(columns), num_rows_(num_rows) {} [[nodiscard]] std::size_t Size() const { return num_rows_ * columns_.size(); } [[nodiscard]] __device__ __forceinline__ COOTuple GetElement(size_t idx) const { size_t column_idx = idx % columns_.size(); @@ -102,37 +105,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { */ class CudfAdapter : public detail::SingleBatchDataIter { public: - explicit CudfAdapter(StringView cuda_interfaces_str) { - Json interfaces = Json::Load(cuda_interfaces_str); - std::vector const& json_columns = get(interfaces); - size_t n_columns = json_columns.size(); - CHECK_GT(n_columns, 0) << "Number of columns must not equal to 0."; - - auto const& typestr = get(json_columns[0]["typestr"]); - CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); - std::vector> columns; - auto first_column = ArrayInterface<1>(get(json_columns[0])); - num_rows_ = first_column.Shape<0>(); - if (num_rows_ == 0) { - return; - } - - device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(first_column.data)); - CHECK(device_.IsCUDA()); - dh::safe_cuda(cudaSetDevice(device_.ordinal)); - for (auto& json_col : json_columns) { - auto column = ArrayInterface<1>(get(json_col)); - n_bytes_ += column.ElementSize() * column.Shape<0>(); - columns.push_back(column); - num_rows_ = std::max(num_rows_, column.Shape<0>()); - CHECK_EQ(device_.ordinal, dh::CudaGetPointerDevice(column.data)) - << "All columns should use the same device."; - CHECK_EQ(num_rows_, column.Shape<0>()) - << "All columns should have same number of rows."; - } - columns_ = columns; - batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_); - } + explicit CudfAdapter(StringView cuda_interfaces_str); explicit CudfAdapter(std::string cuda_interfaces_str) : CudfAdapter{StringView{cuda_interfaces_str}} {} @@ -144,11 +117,26 @@ class CudfAdapter : public detail::SingleBatchDataIter { [[nodiscard]] std::size_t NumRows() const { return num_rows_; } [[nodiscard]] std::size_t NumColumns() const { return columns_.size(); } [[nodiscard]] DeviceOrd Device() const { return device_; } - [[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; } + [[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; } + + [[nodiscard]] enc::DeviceColumnsView Cats() const { + return {common::Span{this->cats_}, dh::ToSpan(this->cat_segments_), this->n_total_cats_}; + } + [[nodiscard]] enc::DeviceColumnsView DCats() const { + return {dh::ToSpan(this->d_cats_), dh::ToSpan(this->cat_segments_), this->n_total_cats_}; + } + [[nodiscard]] bool HasCategorical() const { return !(n_total_cats_ == 0); } private: CudfAdapterBatch batch_; dh::device_vector> columns_; + + // Categories + std::vector cats_; + dh::device_vector d_cats_; + dh::device_vector cat_segments_; + std::int32_t n_total_cats_{0}; + size_t num_rows_{0}; bst_idx_t n_bytes_{0}; DeviceOrd device_{DeviceOrd::CPU()}; @@ -158,12 +146,12 @@ class CupyAdapterBatch : public detail::NoMetaInfo { public: CupyAdapterBatch() = default; explicit CupyAdapterBatch(ArrayInterface<2> array_interface) - : array_interface_(std::move(array_interface)) {} + : array_interface_(std::move(array_interface)) {} // The total number of elements. [[nodiscard]] std::size_t Size() const { return array_interface_.Shape<0>() * array_interface_.Shape<1>(); } - [[nodiscard]]__device__ COOTuple GetElement(size_t idx) const { + [[nodiscard]] __device__ COOTuple GetElement(size_t idx) const { size_t column_idx = idx % array_interface_.Shape<1>(); size_t row_idx = idx / array_interface_.Shape<1>(); float value = array_interface_(row_idx, column_idx); @@ -202,7 +190,7 @@ class CupyAdapter : public detail::SingleBatchDataIter { [[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape<0>(); } [[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape<1>(); } [[nodiscard]] DeviceOrd Device() const { return device_; } - [[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; } + [[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; } private: ArrayInterface<2> array_interface_; diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index fba7ce31cc1a..b2f51671b12b 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -1,9 +1,11 @@ /** - * Copyright 2019-2024, XGBoost Contributors - * \file simple_dmatrix.cu + * Copyright 2019-2025, XGBoost Contributors */ -#include +#include // for int32_t, int8_t +#include // for make_shared + +#include "cat_container.h" // for CatContainer #include "device_adapter.cuh" // for CurrentDevice #include "simple_dmatrix.cuh" #include "simple_dmatrix.h" @@ -42,6 +44,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); + if constexpr (std::is_same_v) { + if (adapter->HasCategorical()) { + info_.Cats(std::make_shared(adapter->Device(), adapter->Cats())); + } + } this->info_.SynchronizeNumberOfColumns(&ctx, data_split_mode); this->fmat_ctx_ = ctx; diff --git a/tests/python-gpu/test_gpu_ordinal.py b/tests/python-gpu/test_gpu_ordinal.py new file mode 100644 index 000000000000..5ab8a717a219 --- /dev/null +++ b/tests/python-gpu/test_gpu_ordinal.py @@ -0,0 +1,14 @@ +import pytest + +from xgboost import testing as tm +from xgboost.testing.ordinal import run_cat_container, run_cat_container_mixed + +pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_arrow(), tm.no_cudf())) + + +def test_cat_container() -> None: + run_cat_container("cuda") + + +def test_cat_container_mixed() -> None: + run_cat_container_mixed("cuda") diff --git a/tests/python/test_ordinal.py b/tests/python/test_ordinal.py index 837ec883a72d..9c46202cd95a 100644 --- a/tests/python/test_ordinal.py +++ b/tests/python/test_ordinal.py @@ -11,4 +11,4 @@ def test_cat_container() -> None: def test_cat_container_mixed() -> None: - run_cat_container_mixed() + run_cat_container_mixed("cpu") From e47cece1b18e391eb008aa23a58abed9b31b200d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 00:13:43 +0800 Subject: [PATCH 02/12] Cleanup. --- python-package/xgboost/_data_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index 85f2fe426520..060c62106ecb 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -335,11 +335,6 @@ def array_interface_dict( # pylint: disable=too-many-locals ) -> Union[ArrayInf, Tuple[StringArray, ArrayInf, Optional[Tuple]]]: """Returns an array interface from the input.""" # Handle categorical values - if is_arrow_dict(data): - cats = data.dictionary - codes = data.indices - jnames, jcodes, buf = _arrow_cat_inf(cats, codes) - return jnames, jcodes, buf if _is_df_cat(data): cats = data.categories # pandas uses -1 to represent missing values for categorical features From 7de280a12a37cb9bfe0bd04653f9e82eadf51c21 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 00:28:11 +0800 Subject: [PATCH 03/12] type. --- python-package/xgboost/_data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index 060c62106ecb..ed71414adfd1 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -367,7 +367,7 @@ def array_interface_dict( # pylint: disable=too-many-locals "version": 3, "mask": None, } - jnames = {"offsets": joffsets, "values": jvalues} + jnames: StringArray = {"offsets": joffsets, "values": jvalues} code_values = codes.values jcodes = array_interface_dict(code_values) From f263e48196d9fd49cb2eb8b9371c36745e1f86fd Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 01:02:54 +0800 Subject: [PATCH 04/12] import. --- python-package/xgboost/testing/ordinal.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py index f4925bdb1313..6205165fe163 100644 --- a/python-package/xgboost/testing/ordinal.py +++ b/python-package/xgboost/testing/ordinal.py @@ -9,6 +9,7 @@ from ..compat import import_cupy from ..core import DMatrix +from ..data import _lazy_load_cudf_is_cat from .data import is_pd_cat_dtype, make_categorical @@ -79,9 +80,13 @@ def run_cat_container_mixed(device: Literal["cpu", "cuda"]) -> None: """Run checks with mixed types.""" import pandas as pd - from ..data import _lazy_load_cudf_is_cat + try: + is_cudf_cat = _lazy_load_cudf_is_cat() + except ImportError: + + def is_cudf_cat(dtype: Any) -> bool: + return False - is_cudf_cat = _lazy_load_cudf_is_cat() n_samples = int(2**10) def check(Xy: DMatrix, X: pd.DataFrame) -> None: From 61a8fdc4e9a26b224dcd6dfd4416242821935115 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 01:09:26 +0800 Subject: [PATCH 05/12] lint. --- python-package/xgboost/testing/ordinal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py index 6205165fe163..86b1ae71018c 100644 --- a/python-package/xgboost/testing/ordinal.py +++ b/python-package/xgboost/testing/ordinal.py @@ -76,6 +76,7 @@ def run_cat_container(device: Literal["cpu", "cuda"]) -> None: Xy = DMatrix(df, enable_categorical=True) +# pylint: disable=too-many-statements def run_cat_container_mixed(device: Literal["cpu", "cuda"]) -> None: """Run checks with mixed types.""" import pandas as pd @@ -84,7 +85,7 @@ def run_cat_container_mixed(device: Literal["cpu", "cuda"]) -> None: is_cudf_cat = _lazy_load_cudf_is_cat() except ImportError: - def is_cudf_cat(dtype: Any) -> bool: + def is_cudf_cat(_: Any) -> bool: return False n_samples = int(2**10) From 31c8c295976c541d17f81870b1755dddcffa648f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:17:57 +0800 Subject: [PATCH 06/12] Fix. --- tests/python-gpu/test_from_cudf.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 5565c6bf4b3f..be9958bf547e 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -212,26 +212,27 @@ def test_cudf_categorical(self) -> None: # mixed dtypes X["0"] = X["0"].astype(np.int64) X["2"] = X["2"].astype(np.int64) - df, cat_codes, _, _ = xgb.data._transform_cudf_df( + df, _, _ = xgb.data._transform_cudf_df( X, None, None, enable_categorical=True ) assert X.shape[1] == n_features - assert len(cat_codes) == X.shape[1] - assert not cat_codes[0] - assert not cat_codes[2] + assert isinstance(df.aitfs[0], dict) + assert isinstance(df.aitfs[1], tuple) + assert isinstance(df.aitfs[2], dict) - interfaces_str = xgb.data._cudf_array_interfaces(df, cat_codes) + interfaces_str = df.array_interface() interfaces = json.loads(interfaces_str) assert len(interfaces) == X.shape[1] # test missing value X = cudf.DataFrame({"f0": ["a", "b", np.nan]}) X["f0"] = X["f0"].astype("category") - df, cat_codes, _, _ = xgb.data._transform_cudf_df( + df, _, _ = xgb.data._transform_cudf_df( X, None, None, enable_categorical=True ) - for col in cat_codes: - assert col.has_nulls + for col in df.aitfs: + assert isinstance(col, tuple) + assert "mask" in col[1] y = [0, 1, 2] with pytest.raises(ValueError): From 648589321ba76b4354cd344ff7465f893a3a60fa Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:22:24 +0800 Subject: [PATCH 07/12] Windows. --- include/xgboost/data.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 954ffc586006..d82d480c805c 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -43,7 +43,7 @@ enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; enum class DataSplitMode : int { kRow = 0, kCol = 1 }; // Forward declaration of the container used by the meta info. -struct CatContainer; +class CatContainer; /** * @brief Meta information about dataset, always sit in memory. From ad268a20da2972e92976894687be5b332612d2a1 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:25:51 +0800 Subject: [PATCH 08/12] empty partition. --- src/data/device_adapter.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/data/device_adapter.cu b/src/data/device_adapter.cu index 0d37e2095959..52fcf1880638 100644 --- a/src/data/device_adapter.cu +++ b/src/data/device_adapter.cu @@ -50,7 +50,12 @@ CudfAdapter::CudfAdapter(StringView cuda_arrinf) { this->d_cats_ = this->cats_; // thrust copy CHECK(!columns.empty()); - device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(columns.front().data)); + if (device == -1) { + CHECK_EQ(columns.front().Shape<0>(), 0); + device_ = DeviceOrd::CUDA(0); + } else { + device_ = DeviceOrd::CUDA(device); + } CHECK(device_.IsCUDA()); curt::SetDevice(device_.ordinal); From 881474031a3138314958878fd1173052bdc297b4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:38:29 +0800 Subject: [PATCH 09/12] rename. --- python-package/xgboost/_data_utils.py | 10 +++++----- python-package/xgboost/data.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index ed71414adfd1..9b381facca93 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -185,7 +185,7 @@ def is_arrow_dict(data: Any) -> TypeGuard["pa.DictionaryArray"]: return lazy_isinstance(data, "pyarrow.lib", "DictionaryArray") -class PdCatAccessor(Protocol): +class DfCatAccessor(Protocol): """Protocol for pandas cat accessor.""" @property @@ -207,7 +207,7 @@ def to_arrow( # pylint: disable=missing-function-docstring def __cuda_array_interface__(self) -> ArrayInf: ... -def _is_df_cat(data: Any) -> TypeGuard[PdCatAccessor]: +def _is_df_cat(data: Any) -> TypeGuard[DfCatAccessor]: # Test pd.Series.cat, not pd.Series return hasattr(data, "categories") and hasattr(data, "codes") @@ -320,7 +320,7 @@ def array_interface_dict(data: np.ndarray) -> ArrayInf: ... @overload def array_interface_dict( - data: PdCatAccessor, + data: DfCatAccessor, ) -> Tuple[StringArray, ArrayInf, Tuple]: ... @@ -331,7 +331,7 @@ def array_interface_dict( def array_interface_dict( # pylint: disable=too-many-locals - data: Union[np.ndarray, PdCatAccessor], + data: Union[np.ndarray, DfCatAccessor], ) -> Union[ArrayInf, Tuple[StringArray, ArrayInf, Optional[Tuple]]]: """Returns an array interface from the input.""" # Handle categorical values @@ -407,7 +407,7 @@ def check_cudf_meta(data: _CudaArrayLikeArg, field: str) -> None: def cudf_cat_inf( - cats: PdCatAccessor, codes: "pd.Series" + cats: DfCatAccessor, codes: "pd.Series" ) -> Tuple[Union[ArrayInf, StringArray], ArrayInf, Tuple]: """Obtain the cuda array interface for cuDF categories.""" cp = import_cupy() diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index ea95df7556d4..4ada55472348 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -24,7 +24,7 @@ from ._data_utils import ( ArrayInf, - PdCatAccessor, + DfCatAccessor, StringArray, TransformedDf, _ensure_np_dtype, @@ -536,14 +536,14 @@ def _lazy_load_pd_floats() -> tuple: return Float32Dtype, Float64Dtype -def pandas_transform_data(data: DataFrame) -> List[Union[np.ndarray, PdCatAccessor]]: +def pandas_transform_data(data: DataFrame) -> List[Union[np.ndarray, DfCatAccessor]]: """Handle categorical dtype and extension types from pandas.""" Float32Dtype, Float64Dtype = _lazy_load_pd_floats() - result: List[Union[np.ndarray, PdCatAccessor]] = [] + result: List[Union[np.ndarray, DfCatAccessor]] = [] np_dtypes = _lazy_has_npdtypes() - def cat_codes(ser: "PdSeries") -> PdCatAccessor: + def cat_codes(ser: "PdSeries") -> DfCatAccessor: return ser.cat def nu_type(ser: "PdSeries") -> np.ndarray: @@ -608,7 +608,7 @@ class PandasTransformed(TransformedDf): """A storage class for transformed pandas DataFrame.""" def __init__( - self, columns: List[Union[np.ndarray, PdCatAccessor, "pa.DictionaryType"]] + self, columns: List[Union[np.ndarray, DfCatAccessor, "pa.DictionaryType"]] ) -> None: self.columns = columns @@ -1067,7 +1067,7 @@ def _lazy_load_cudf_is_bool() -> Callable[[Any], bool]: class CudfTransformed(TransformedDf): """A storage class for transformed cuDF dataframe.""" - def __init__(self, columns: List[Union["PdSeries", PdCatAccessor]]) -> None: + def __init__(self, columns: List[Union["PdSeries", DfCatAccessor]]) -> None: self.columns = columns # Buffers for temporary data that cannot be freed until the data is consumed by # the DMatrix or the booster. From ac9a764c9de05059072923f90170b9a28b93651f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:45:41 +0800 Subject: [PATCH 10/12] lint. --- tests/python-gpu/test_from_cudf.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index be9958bf547e..8c263901d94a 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -212,9 +212,7 @@ def test_cudf_categorical(self) -> None: # mixed dtypes X["0"] = X["0"].astype(np.int64) X["2"] = X["2"].astype(np.int64) - df, _, _ = xgb.data._transform_cudf_df( - X, None, None, enable_categorical=True - ) + df, _, _ = xgb.data._transform_cudf_df(X, None, None, enable_categorical=True) assert X.shape[1] == n_features assert isinstance(df.aitfs[0], dict) assert isinstance(df.aitfs[1], tuple) @@ -227,9 +225,7 @@ def test_cudf_categorical(self) -> None: # test missing value X = cudf.DataFrame({"f0": ["a", "b", np.nan]}) X["f0"] = X["f0"].astype("category") - df, _, _ = xgb.data._transform_cudf_df( - X, None, None, enable_categorical=True - ) + df, _, _ = xgb.data._transform_cudf_df(X, None, None, enable_categorical=True) for col in df.aitfs: assert isinstance(col, tuple) assert "mask" in col[1] From aa0607b3bad56c65a3434d1c5fda68ed9ef23ea9 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 02:54:05 +0800 Subject: [PATCH 11/12] Cleanup. --- python-package/xgboost/_data_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python-package/xgboost/_data_utils.py b/python-package/xgboost/_data_utils.py index 9b381facca93..c651fa03c709 100644 --- a/python-package/xgboost/_data_utils.py +++ b/python-package/xgboost/_data_utils.py @@ -288,9 +288,7 @@ def make_array_inf( """Helper for handling categorical codes.""" # Handle cuDF data if hasattr(array, "__cuda_array_interface__"): - inf = array.__cuda_array_interface__ - if "mask" in inf: - inf["mask"] = inf["mask"].__cuda_array_interface__ + inf = cuda_array_interface_dict(array) return inf, None # Other types (like arrow itself) are not yet supported. @@ -416,9 +414,7 @@ def cudf_cat_inf( ) if is_num_idx: cats_ainf = cats.__cuda_array_interface__ - codes_ainf = codes.__cuda_array_interface__ - if "mask" in codes_ainf: - codes_ainf["mask"] = codes_ainf["mask"].__cuda_array_interface__ + codes_ainf = cuda_array_interface_dict(codes) return cats_ainf, codes_ainf, (cats, codes) joffset, jdata, buf = _arrow_cat_inf(cats.to_arrow(), codes) From 7b7bfea76e2690be79777ae14c216269350b7e93 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 6 Mar 2025 03:48:13 +0800 Subject: [PATCH 12/12] return -1. --- src/common/device_helpers.cuh | 3 +++ src/data/device_adapter.cu | 7 ++++--- src/data/simple_dmatrix.cu | 7 ++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 8e932cdaf8ba..5292edbf3591 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -92,6 +92,9 @@ XGBOOST_DEV_INLINE T atomicAdd(T *addr, T v) { // NOLINT namespace dh { inline int32_t CudaGetPointerDevice(void const *ptr) { + if (!ptr) { + return -1; + } int32_t device = -1; cudaPointerAttributes attr; dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); diff --git a/src/data/device_adapter.cu b/src/data/device_adapter.cu index 52fcf1880638..1462880eca1d 100644 --- a/src/data/device_adapter.cu +++ b/src/data/device_adapter.cu @@ -1,7 +1,7 @@ /** * Copyright 2019-2025, XGBoost Contributors */ -#include "../common/cuda_rt_utils.h" +#include "../common/cuda_rt_utils.h" // for SetDevice, CurrentDevice #include "device_adapter.cuh" namespace xgboost::data { @@ -50,9 +50,10 @@ CudfAdapter::CudfAdapter(StringView cuda_arrinf) { this->d_cats_ = this->cats_; // thrust copy CHECK(!columns.empty()); - if (device == -1) { + if (device < 0) { + // Empty dataset CHECK_EQ(columns.front().Shape<0>(), 0); - device_ = DeviceOrd::CUDA(0); + device_ = DeviceOrd::CUDA(curt::CurrentDevice()); } else { device_ = DeviceOrd::CUDA(device); } diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index b2f51671b12b..1436d982bc29 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -5,8 +5,9 @@ #include // for int32_t, int8_t #include // for make_shared -#include "cat_container.h" // for CatContainer -#include "device_adapter.cuh" // for CurrentDevice +#include "../common/cuda_rt_utils.h" // for CurrentDevice +#include "cat_container.h" // for CatContainer +#include "device_adapter.cuh" // for CurrentDevice #include "simple_dmatrix.cuh" #include "simple_dmatrix.h" #include "xgboost/context.h" // for Context @@ -22,7 +23,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, std::int32_t nthr CHECK(data_split_mode != DataSplitMode::kCol) << "Column-wise data split is currently not supported on the GPU."; auto device = (!adapter->Device().IsCUDA() || adapter->NumRows() == 0) - ? DeviceOrd::CUDA(dh::CurrentDevice()) + ? DeviceOrd::CUDA(curt::CurrentDevice()) : adapter->Device(); CHECK(device.IsCUDA()); dh::safe_cuda(cudaSetDevice(device.ordinal));