Skip to content

Commit

Permalink
Use JSON for parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 17, 2022
1 parent c83fc07 commit 629fc5b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 31 deletions.
14 changes: 3 additions & 11 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -506,17 +506,9 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, void* ptr_schema);

XGB_DLL int XGDMatrixCreateFromArrowCallback(
XGDMatrixCallbackNext *next,
float missing,
int nthread,
const char* label_col_name,
const char* label_lb_col_name,
const char* label_ub_col_name,
const char* weight_col_name,
const char* base_margin_col_name,
const char* qid_col_name,
DMatrixHandle *out);

XGDMatrixCallbackNext *next, char const *json_config, const char *label_col_name,
const char *label_lb_col_name, const char *label_ub_col_name, const char *weight_col_name,
const char *base_margin_col_name, const char *qid_col_name, DMatrixHandle *out);

/*
* ==========================- End data callback APIs ==========================
Expand Down
22 changes: 13 additions & 9 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,34 +473,38 @@ def _is_arrow(data):

def _from_arrow(
data,
missing,
nthread,
missing: float,
nthread: int,
feature_names: Optional[List[str]],
feature_types: Optional[List[str]],
enable_categorical: bool
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
if not all(pa.types.is_integer(t) or pa.types.is_floating(t)
for t in data.schema.types):
if not all(
pa.types.is_integer(t) or pa.types.is_floating(t) for t in data.schema.types
):
raise ValueError(
'Features in dataset can only be integers or floating point number')
"Features in dataset can only be integers or floating point number"
)
if enable_categorical:
raise ValueError("categorical data in datatable is not supported yet.")

rb_iter = iter(data.to_batches())
it = RecordBatchDataIter(rb_iter)
next_callback = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it.next)
handle = ctypes.c_void_p()

args = {"missing": missing, "nthread": nthread}
config = bytes(json.dumps(args), "utf-8")
ret = _LIB.XGDMatrixCreateFromArrowCallback(
next_callback,
ctypes.c_float(missing),
ctypes.c_int(nthread),
config,
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.POINTER(ctypes.c_char_p)(),
ctypes.byref(handle)
ctypes.byref(handle),
)
_check_call(ret)
return (handle, feature_names, feature_types)
Expand Down
10 changes: 6 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,7 @@ XGB_DLL int XGImportRecordBatch(DataIterHandle data_handle, void* ptr_array, voi

XGB_DLL int XGDMatrixCreateFromArrowCallback(
XGDMatrixCallbackNext *next,
float missing,
int nthread,
char const* json_config,
const char* label_col_name,
const char* label_lb_col_name,
const char* label_ub_col_name,
Expand All @@ -442,7 +441,10 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback(
const char* qid_col_name,
DMatrixHandle *out) {
API_BEGIN();
int n_threads = common::OmpGetNumThreads(nthread);
auto config = Json::Load(StringView{json_config});
auto missing = GetMissing(config);
int32_t n_threads = get<Integer const>(config["nthread"]);
n_threads = common::OmpGetNumThreads(n_threads);
data::RecordBatchesIterAdapter adapter(next,
n_threads,
label_col_name,
Expand All @@ -452,7 +454,7 @@ XGB_DLL int XGDMatrixCreateFromArrowCallback(
base_margin_col_name,
qid_col_name);
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, nthread));
DMatrix::Create(&adapter, missing, n_threads));
API_END();
}

Expand Down
13 changes: 6 additions & 7 deletions src/data/simple_dmatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,10 @@ template SimpleDMatrix::SimpleDMatrix(
*adapter,
float missing, int nthread);

template<>
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter,
float missing, int nthread) {
constexpr uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
template <>
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
constexpr uint64_t kDefaultMax = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = kDefaultMax;
bst_uint group_size = 0;
auto& offset_vec = sparse_page_->offset.HostVector();
auto& data_vec = sparse_page_->data.HostVector();
Expand Down Expand Up @@ -353,7 +352,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter,
// get group
for (size_t i = 0; i < batches[i]->Size(); ++i) {
const uint64_t cur_group_id = batches[i]->Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
if (last_group_id == kDefaultMax || last_group_id != cur_group_id) {
info_.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
Expand All @@ -362,7 +361,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter,
}
}
}
if (last_group_id != default_max) {
if (last_group_id != kDefaultMax) {
if (group_size > info_.group_ptr_.back()) {
info_.group_ptr_.push_back(group_size);
}
Expand Down

0 comments on commit 629fc5b

Please sign in to comment.