Skip to content

Commit e920125

Browse files
committed
Small cleanup to the proxy dmatrix.
- Remove the auto-dispatch for CUDA methods.
1 parent 86a9809 commit e920125

File tree

13 files changed: +70 -73 lines changed

src/c_api/c_api.cc

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ XGB_DLL int XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
431431
CHECK(p_m);
432432
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
433433
CHECK(m) << "Current DMatrix type does not support set data.";
434-
m->SetCUDAArray(c_interface_str);
434+
m->SetCudaArray(c_interface_str);
435435
API_END();
436436
}
437437

@@ -443,19 +443,19 @@ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle, char const *
443443
CHECK(p_m);
444444
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
445445
CHECK(m) << "Current DMatrix type does not support set data.";
446-
m->SetCUDAArray(c_interface_str);
446+
m->SetCudaColumnar(c_interface_str);
447447
API_END();
448448
}
449449

450-
XGB_DLL int XGProxyDMatrixSetDataColumnar(DMatrixHandle handle, char const *c_interface_str) {
450+
XGB_DLL int XGProxyDMatrixSetDataColumnar(DMatrixHandle handle, char const *data) {
451451
API_BEGIN();
452452
CHECK_HANDLE();
453-
xgboost_CHECK_C_ARG_PTR(c_interface_str);
453+
xgboost_CHECK_C_ARG_PTR(data);
454454
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
455455
CHECK(p_m);
456456
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
457457
CHECK(m) << "Current DMatrix type does not support set data.";
458-
m->SetColumnarData(c_interface_str);
458+
m->SetColumnar(data);
459459
API_END();
460460
}
461461

@@ -467,7 +467,7 @@ XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle, char const *c_inter
467467
CHECK(p_m);
468468
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
469469
CHECK(m) << "Current DMatrix type does not support set data.";
470-
m->SetArrayData(c_interface_str);
470+
m->SetArray(c_interface_str);
471471
API_END();
472472
}
473473

@@ -482,7 +482,7 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr, c
482482
CHECK(p_m);
483483
auto m = static_cast<xgboost::data::DMatrixProxy *>(p_m->get());
484484
CHECK(m) << "Current DMatrix type does not support set data.";
485-
m->SetCSRData(indptr, indices, data, ncol, true);
485+
m->SetCsr(indptr, indices, data, ncol, true);
486486
API_END();
487487
}
488488

@@ -1417,7 +1417,7 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_in
14171417
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14181418
CHECK(proxy) << "Invalid input type for inplace predict.";
14191419
xgboost_CHECK_C_ARG_PTR(array_interface);
1420-
proxy->SetArrayData(array_interface);
1420+
proxy->SetArray(array_interface);
14211421
auto *learner = static_cast<xgboost::Learner *>(handle);
14221422
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14231423
API_END();
@@ -1438,7 +1438,7 @@ XGB_DLL int XGBoosterPredictFromColumnar(BoosterHandle handle, char const *array
14381438
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14391439
CHECK(proxy) << "Invalid input type for inplace predict.";
14401440
xgboost_CHECK_C_ARG_PTR(array_interface);
1441-
proxy->SetColumnarData(array_interface);
1441+
proxy->SetColumnar(array_interface);
14421442
auto *learner = static_cast<xgboost::Learner *>(handle);
14431443
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14441444
API_END();
@@ -1460,7 +1460,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
14601460
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
14611461
CHECK(proxy) << "Invalid input type for inplace predict.";
14621462
xgboost_CHECK_C_ARG_PTR(indptr);
1463-
proxy->SetCSRData(indptr, indices, data, cols, true);
1463+
proxy->SetCsr(indptr, indices, data, cols, true);
14641464
auto *learner = static_cast<xgboost::Learner *>(handle);
14651465
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
14661466
API_END();

src/c_api/c_api.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
150150
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
151151
CHECK(proxy) << "Invalid input type for inplace predict.";
152152

153-
proxy->SetCUDAArray(c_array_interface);
153+
proxy->SetCudaArray(c_array_interface);
154154

155155
auto config = Json::Load(StringView{c_json_config});
156156
auto *learner = static_cast<Learner *>(handle);

src/data/proxy_dmatrix.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2021-2024, XGBoost Contributors
2+
* Copyright 2021-2025, XGBoost Contributors
33
*/
44

55
#include "proxy_dmatrix.h"
@@ -16,23 +16,23 @@
1616
#endif
1717

1818
namespace xgboost::data {
19-
void DMatrixProxy::SetColumnarData(StringView interface_str) {
19+
void DMatrixProxy::SetColumnar(StringView interface_str) {
2020
std::shared_ptr<ColumnarAdapter> adapter{new ColumnarAdapter{interface_str}};
2121
this->batch_ = adapter;
2222
this->Info().num_col_ = adapter->NumColumns();
2323
this->Info().num_row_ = adapter->NumRows();
2424
this->ctx_.Init(Args{{"device", "cpu"}});
2525
}
2626

27-
void DMatrixProxy::SetArrayData(StringView interface_str) {
27+
void DMatrixProxy::SetArray(StringView interface_str) {
2828
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{interface_str}};
2929
this->batch_ = adapter;
3030
this->Info().num_col_ = adapter->NumColumns();
3131
this->Info().num_row_ = adapter->NumRows();
3232
this->ctx_.Init(Args{{"device", "cpu"}});
3333
}
3434

35-
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, char const *c_values,
35+
void DMatrixProxy::SetCsr(char const *c_indptr, char const *c_indices, char const *c_values,
3636
bst_feature_t n_features, bool on_host) {
3737
CHECK(on_host) << "Not implemented on device.";
3838
std::shared_ptr<CSRArrayAdapter> adapter{new CSRArrayAdapter(
@@ -43,6 +43,11 @@ void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, char
4343
this->ctx_.Init(Args{{"device", "cpu"}});
4444
}
4545

46+
#if !defined(XGBOOST_USE_CUDA)
47+
void DMatrixProxy::SetCudaArray(StringView) { common::AssertGPUSupport(); }
48+
void DMatrixProxy::SetCudaColumnar(StringView) { common::AssertGPUSupport(); }
49+
#endif // !defined(XGBOOST_USE_CUDA)
50+
4651
namespace cuda_impl {
4752
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
4853
std::shared_ptr<DMatrixProxy> proxy, float missing);

src/data/proxy_dmatrix.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#include "proxy_dmatrix.h"
88

99
namespace xgboost::data {
10-
void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
11-
auto adapter{std::make_shared<CudfAdapter>(interface_str)};
10+
void DMatrixProxy::SetCudaColumnar(StringView data) {
11+
auto adapter{std::make_shared<CudfAdapter>(data)};
1212
this->batch_ = adapter;
1313
this->Info().num_col_ = adapter->NumColumns();
1414
this->Info().num_row_ = adapter->NumRows();
@@ -21,8 +21,8 @@ void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
2121
ctx_ = ctx_.MakeCUDA(adapter->Device().ordinal);
2222
}
2323

24-
void DMatrixProxy::FromCudaArray(StringView interface_str) {
25-
auto adapter(std::make_shared<CupyAdapter>(StringView{interface_str}));
24+
void DMatrixProxy::SetCudaArray(StringView data) {
25+
auto adapter(std::make_shared<CupyAdapter>(StringView{data}));
2626
this->batch_ = adapter;
2727
this->Info().num_col_ = adapter->NumColumns();
2828
this->Info().num_row_ = adapter->NumRows();

src/data/proxy_dmatrix.h

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -65,51 +65,39 @@ class DataIterProxy {
6565
};
6666

6767
/**
68-
* @brief A proxy of DMatrix used by external iterator.
68+
* @brief A proxy of DMatrix used by the external iterator.
6969
*/
7070
class DMatrixProxy : public DMatrix {
7171
MetaInfo info_;
7272
std::any batch_;
7373
Context ctx_;
7474

75-
#if defined(XGBOOST_USE_CUDA)
76-
void FromCudaColumnar(StringView interface_str);
77-
void FromCudaArray(StringView interface_str);
78-
#endif // defined(XGBOOST_USE_CUDA)
79-
8075
public:
8176
DeviceOrd Device() const { return ctx_.Device(); }
8277

83-
void SetCUDAArray(char const* c_interface) {
84-
common::AssertGPUSupport();
85-
CHECK(c_interface);
86-
#if defined(XGBOOST_USE_CUDA)
87-
StringView interface_str{c_interface};
88-
Json json_array_interface = Json::Load(interface_str);
89-
if (IsA<Array>(json_array_interface)) {
90-
this->FromCudaColumnar(interface_str);
91-
} else {
92-
this->FromCudaArray(interface_str);
93-
}
94-
#endif // defined(XGBOOST_USE_CUDA)
95-
}
96-
97-
void SetColumnarData(StringView interface_str);
98-
99-
void SetArrayData(StringView interface_str);
100-
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
78+
/**
79+
* Device setters
80+
*/
81+
void SetCudaColumnar(StringView data);
82+
void SetCudaArray(StringView data);
83+
/**
84+
* Host setters
85+
*/
86+
void SetColumnar(StringView data);
87+
void SetArray(StringView data);
88+
void SetCsr(char const* c_indptr, char const* c_indices, char const* c_values,
10189
bst_feature_t n_features, bool on_host);
10290

10391
MetaInfo& Info() override { return info_; }
10492
MetaInfo const& Info() const override { return info_; }
10593
Context const* Ctx() const override { return &ctx_; }
10694

107-
bool EllpackExists() const override { return false; }
108-
bool GHistIndexExists() const override { return false; }
109-
bool SparsePageExists() const override { return false; }
95+
[[nodiscard]] bool EllpackExists() const override { return false; }
96+
[[nodiscard]] bool GHistIndexExists() const override { return false; }
97+
[[nodiscard]] bool SparsePageExists() const override { return false; }
11098

11199
template <typename Page>
112-
BatchSet<Page> NoBatch() {
100+
static BatchSet<Page> NoBatch() {
113101
LOG(FATAL) << "Proxy DMatrix cannot return data batch.";
114102
return BatchSet<Page>(BatchIterator<Page>(nullptr));
115103
}

tests/cpp/data/test_proxy_dmatrix.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
/**
2-
* Copyright 2021-2023, XGBoost contributors
2+
* Copyright 2021-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55

6-
#include "../../../src/data/adapter.h"
6+
#include <cstddef> // for size_t
7+
#include <vector> // for vector
8+
79
#include "../../../src/data/proxy_dmatrix.h"
810
#include "../helpers.h"
11+
#include "xgboost/host_device_vector.h" // for HostDeviceVector
912

1013
namespace xgboost::data {
1114
TEST(ProxyDMatrix, HostData) {
1215
DMatrixProxy proxy;
13-
size_t constexpr kRows = 100, kCols = 10;
16+
std::size_t constexpr kRows = 100, kCols = 10;
1417
std::vector<HostDeviceVector<float>> label_storage(1);
1518

1619
HostDeviceVector<float> storage;
1720
auto data =
1821
RandomDataGenerator(kRows, kCols, 0.5).Device(FstCU()).GenerateArrayInterface(&storage);
1922

20-
proxy.SetArrayData(data.c_str());
23+
proxy.SetArray(data.c_str());
2124

2225
auto n_samples = HostAdapterDispatch(&proxy, [](auto const &value) { return value.Size(); });
2326
ASSERT_EQ(n_samples, kRows);

tests/cpp/data/test_proxy_dmatrix.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
/**
2-
* Copyright 2020-2023 XGBoost contributors
2+
* Copyright 2020-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/host_device_vector.h>
66

7-
#include <any> // for any_cast
8-
#include <memory>
7+
#include <any> // for any_cast
8+
#include <memory> // for shared_ptr
9+
#include <vector> // for vector
910

1011
#include "../../../src/data/device_adapter.cuh"
1112
#include "../../../src/data/proxy_dmatrix.h"
1213
#include "../helpers.h"
14+
#include "xgboost/host_device_vector.h" // for HostDeviceVector
1315

1416
namespace xgboost::data {
1517
TEST(ProxyDMatrix, DeviceData) {
@@ -23,7 +25,7 @@ TEST(ProxyDMatrix, DeviceData) {
2325
.GenerateColumnarArrayInterface(&label_storage);
2426

2527
DMatrixProxy proxy;
26-
proxy.SetCUDAArray(data.c_str());
28+
proxy.SetCudaArray(data.c_str());
2729
proxy.SetInfo("label", labels.c_str());
2830

2931
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CupyAdapter>));
@@ -35,7 +37,7 @@ TEST(ProxyDMatrix, DeviceData) {
3537
data = RandomDataGenerator(kRows, kCols, 0)
3638
.Device(FstCU())
3739
.GenerateColumnarArrayInterface(&columnar_storage);
38-
proxy.SetCUDAArray(data.c_str());
40+
proxy.SetCudaColumnar(data.c_str());
3941
ASSERT_EQ(proxy.Adapter().type(), typeid(std::shared_ptr<CudfAdapter>));
4042
ASSERT_EQ(std::any_cast<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumRows(), kRows);
4143
ASSERT_EQ(std::any_cast<std::shared_ptr<CudfAdapter>>(proxy.Adapter())->NumColumns(), kCols);

tests/cpp/gbm/test_gblinear.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023, XGBoost Contributors
2+
* Copyright 2023-2025, XGBoost Contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/global_config.h> // for GlobalConfigThreadLocalStore

tests/cpp/gbm/test_gbtree.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2019-2024, XGBoost contributors
2+
* Copyright 2019-2025, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <xgboost/context.h>
@@ -401,9 +401,9 @@ class Dart : public testing::TestWithParam<char const*> {
401401
HostDeviceVector<float>* inplace_predts;
402402
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
403403
if (ctx.IsCUDA()) {
404-
x->SetCUDAArray(array_str.c_str());
404+
x->SetCudaArray(array_str.c_str());
405405
} else {
406-
x->SetArrayData(array_str.c_str());
406+
x->SetArray(array_str.c_str());
407407
}
408408
learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
409409
&inplace_predts, 0, 0);
@@ -628,7 +628,7 @@ TEST(GBTree, PredictRange) {
628628
HostDeviceVector<float> raw_storage;
629629
auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage);
630630
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
631-
x->SetArrayData(raw.data());
631+
x->SetArray(raw.data());
632632

633633
HostDeviceVector<float>* out_predt;
634634
learner->InplacePredict(x, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),

tests/cpp/gbm/test_gbtree.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/**
2-
* Copyright 2023, XGBoost contributors
2+
* Copyright 2023-2025, XGBoost contributors
33
*/
44
#include <xgboost/context.h> // for Context
55
#include <xgboost/learner.h> // for Learner
@@ -8,7 +8,6 @@
88
#include <limits> // for numeric_limits
99
#include <memory> // for shared_ptr
1010
#include <string> // for string
11-
#include <thread> // for thread
1211

1312
#include "../../../src/data/adapter.h" // for ArrayAdapter
1413
#include "../../../src/data/device_adapter.cuh" // for CupyAdapter
@@ -50,9 +49,9 @@ void TestInplaceFallback(Context const* ctx) {
5049
std::shared_ptr<DMatrix> p_m{new data::DMatrixProxy};
5150
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
5251
if (data_ordinal.IsCPU()) {
53-
proxy->SetArrayData(StringView{X});
52+
proxy->SetArray(StringView{X});
5453
} else {
55-
proxy->SetCUDAArray(X.c_str());
54+
proxy->SetCudaArray(X.c_str());
5655
}
5756

5857
HostDeviceVector<float>* out_predt{nullptr};

0 commit comments

Comments
 (0)