diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 99831478f234..849fcbb8d58b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -506,17 +506,9 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema); XGB_DLL int XGDMatrixCreateFromArrowCallback( - XGDMatrixCallbackNext *next, - float missing, - int nthread, - const char* label_col_name, - const char* label_lb_col_name, - const char* label_ub_col_name, - const char* weight_col_name, - const char* base_margin_col_name, - const char* qid_col_name, - DMatrixHandle *out); - + XGDMatrixCallbackNext *next, char const *json_config, const char *label_col_name, + const char *label_lb_col_name, const char *label_ub_col_name, const char *weight_col_name, + const char *base_margin_col_name, const char *qid_col_name, DMatrixHandle *out); /* * ==========================- End data callback APIs ========================== diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 75a54287e22a..8165a6090636 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -473,16 +473,18 @@ def _is_arrow(data): def _from_arrow( data, - missing, - nthread, + missing: float, + nthread: int, feature_names: Optional[List[str]], feature_types: Optional[List[str]], - enable_categorical: bool + enable_categorical: bool, ) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]: - if not all(pa.types.is_integer(t) or pa.types.is_floating(t) - for t in data.schema.types): + if not all( + pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types + ): raise ValueError( - 'Features in dataset can only be integers or floating point number') + "Features in dataset can only be integers or floating point number" + ) if enable_categorical: raise ValueError("categorical data in datatable is not supported yet.") @@ -490,17 +492,19 @@ def _from_arrow( it = RecordBatchDataIter(rb_iter) next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it.next) handle = ctypes.c_void_p() + + args = {"missing": missing, "nthread": nthread} + config = bytes(json.dumps(args), "utf-8") ret = _LIB.XGDMatrixCreateFromArrowCallback( next_callback, - ctypes.c_float(missing), - ctypes.c_int(nthread), + config, ctypes.POINTER(ctypes.c_char_p)(), ctypes.POINTER(ctypes.c_char_p)(), ctypes.POINTER(ctypes.c_char_p)(), ctypes.POINTER(ctypes.c_char_p)(), ctypes.POINTER(ctypes.c_char_p)(), ctypes.POINTER(ctypes.c_char_p)(), - ctypes.byref(handle) + ctypes.byref(handle), ) _check_call(ret) return (handle, feature_names, feature_types) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a3d300827211..77eeabffce3b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -432,8 +432,7 @@ XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, voi XGB_DLL int XGDMatrixCreateFromArrowCallback( XGDMatrixCallbackNext *next, - float missing, - int nthread, + char const* json_config, const char* label_col_name, const char* label_lb_col_name, const char* label_ub_col_name, @@ -442,7 +441,10 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback( const char* qid_col_name, DMatrixHandle *out) { API_BEGIN(); - int n_threads = common::OmpGetNumThreads(nthread); + auto config = Json::Load(StringView{json_config}); + auto missing = GetMissing(config); + int32_t n_threads = get(config["nthread"]); + n_threads = common::OmpGetNumThreads(n_threads); data::RecordBatchesIterAdapter adapter(next, n_threads, label_col_name, @@ -452,7 +454,7 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback( base_margin_col_name, qid_col_name); *out = new std::shared_ptr( - DMatrix::Create(&adapter, missing, nthread)); + DMatrix::Create(&adapter, missing, n_threads)); API_END(); } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index c859ae5e7149..47e875f1d78a 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -242,11 +242,10 @@ template SimpleDMatrix::SimpleDMatrix( *adapter, float missing, int nthread); -template<> -SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, - float missing, int nthread) { - constexpr uint64_t default_max = std::numeric_limits::max(); - uint64_t last_group_id = default_max; +template <> +SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) { + constexpr uint64_t kDefaultMax = std::numeric_limits::max(); + uint64_t last_group_id = kDefaultMax; bst_uint group_size = 0; auto& offset_vec = sparse_page_->offset.HostVector(); auto& data_vec = sparse_page_->data.HostVector(); @@ -353,7 +352,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, // get group for (size_t i = 0; i < batches[i]->Size(); ++i) { const uint64_t cur_group_id = batches[i]->Qid()[i]; - if (last_group_id == default_max || last_group_id != cur_group_id) { + if (last_group_id == kDefaultMax || last_group_id != cur_group_id) { info_.group_ptr_.push_back(group_size); } last_group_id = cur_group_id; @@ -362,7 +361,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, } } } - if (last_group_id != default_max) { + if (last_group_id != kDefaultMax) { if (group_size > info_.group_ptr_.back()) { info_.group_ptr_.push_back(group_size); }