Skip to content

Commit

Permalink
integration with interface initial attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Mar 22, 2024
1 parent da0f7a6 commit 406cda3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 23 deletions.
12 changes: 2 additions & 10 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
// Under secure mode, gpairs will be processed to vector and encrypt
// information only available on rank 0

xgboost::processing::ProcessorLoader loader;
auto processor = loader.load("dummy");
if (collective::GetRank() == 0) {
processor->Initialize(true, {});
} else {
processor->Initialize(false, {});
}

std::size_t buffer_size{};
std::int8_t *buffer;
//common::Span<std::int8_t> buffer;
Expand All @@ -128,7 +120,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
std::cout << "-----------Call Interface for gp encryption and broadcast"
<< ", size of gpairs: " << vector_gh.size()
<< " ----------------------" << std::endl;
auto buf = processor->ProcessGHPairs(vector_gh);
auto buf = processor_instance->ProcessGHPairs(vector_gh);
buffer_size = buf.size();
buffer = reinterpret_cast<std::int8_t *>(buf.data());
std::cout << "buffer size: " << buffer_size << std::endl;
Expand All @@ -146,7 +138,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*

// call HandleGHPairs
xgboost::common::Span<int8_t> buf = xgboost::common::Span<int8_t>(buffer, buffer_size);
processor->HandleGHPairs(buf);
processor_instance->HandleGHPairs(buf);



Expand Down
8 changes: 4 additions & 4 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace

DECLARE_FIELD_ENUM_CLASS(xgboost::MultiStrategy);

xgboost::processing::Processor *processor_instance;
namespace xgboost {
Learner::~Learner() = default;
namespace {
Expand Down Expand Up @@ -497,11 +497,11 @@ class LearnerConfiguration : public Learner {
std::cout << "configure interface here???????????????" << std::endl;
}
xgboost::processing::ProcessorLoader loader;
auto processor = loader.load("dummy");
processor_instance = loader.load("dummy");
if (collective::GetRank() == 0) {
processor->Initialize(true, {});
processor_instance->Initialize(true, {});
} else {
processor->Initialize(false, {});
processor_instance->Initialize(false, {});
}


Expand Down
3 changes: 3 additions & 0 deletions src/processing/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,7 @@ class ProcessorLoader {

void unload();
};

} // namespace xgboost::processing

extern xgboost::processing::Processor *processor_instance;
11 changes: 2 additions & 9 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,8 @@ class HistogramBuilder {
if ((collective::GetRank() == 0)) {
std::cout << "------------CALL interface to transmit row & gidx---------" << std::endl;
}
xgboost::processing::ProcessorLoader loader;
auto processor = loader.load("dummy");
if (collective::GetRank() == 0) {
processor->Initialize(true, {});
} else {
processor->Initialize(false, {});
}
processor->InitAggregationContext(gidx);
processor->ProcessAggregation(nodes_to_build, row_set_collection);
processor_instance->InitAggregationContext(gidx);
processor_instance->ProcessAggregation(nodes_to_build, row_set_collection);
} else {
// Parallel processing by nodes and data in each node
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
Expand Down

0 comments on commit 406cda3

Please sign in to comment.