Skip to content

Commit

Permalink
Update the processor functions according to new processor implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Apr 12, 2024
1 parent 2997cf7 commit 3a1f9ac
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
11 changes: 6 additions & 5 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
vector_gh.push_back(gpair_ptr[1]);
}
// provide the vectors to the processor interface
auto buf = processor_instance->ProcessGHPairs(vector_gh);
buffer_size = buf.size();
buffer = reinterpret_cast<std::int8_t *>(buf.data());
size_t size;
auto buf = processor_instance->ProcessGHPairs(size, vector_gh);
buffer_size = size;
buffer = reinterpret_cast<std::int8_t *>(buf);
}

// broadcast the buffer size for other ranks to prepare
Expand All @@ -129,8 +130,8 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
collective::Broadcast(buffer, buffer_size, 0);

// call HandleGHPairs
xgboost::common::Span<int8_t> buf = xgboost::common::Span<int8_t>(buffer, buffer_size);
processor_instance->HandleGHPairs(buf);
size_t size;
processor_instance->HandleGHPairs(size, buffer, buffer_size);
} else {
// clear text mode, broadcast the data directly
result->Resize(size);
Expand Down
15 changes: 7 additions & 8 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace

DECLARE_FIELD_ENUM_CLASS(xgboost::MultiStrategy);
xgboost::processing::Processor *processor_instance;
processing::Processor *processor_instance;
namespace xgboost {
Learner::~Learner() = default;
namespace {
Expand Down Expand Up @@ -493,13 +493,12 @@ class LearnerConfiguration : public Learner {

this->ConfigureMetrics(args);

xgboost::processing::ProcessorLoader loader;
processor_instance = loader.load("dummy");
if (collective::GetRank() == 0) {
processor_instance->Initialize(true, {});
} else {
processor_instance->Initialize(false, {});
}
std::map<std::string, std::string> loader_params = {{"LIBRARY_PATH", "/tmp"}};
std::map<std::string, std::string> proc_params = {};
auto plugin_name = "dummy";
processing::ProcessorLoader loader(loader_params);
processor_instance = loader.load(plugin_name);
processor_instance->Initialize(collective::GetRank() == 0, proc_params);

this->need_configuration_ = false;
if (ctx_.validate_parameters) {
Expand Down
31 changes: 26 additions & 5 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,33 @@ class HistogramBuilder {
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
if (is_distributed_ && is_col_split_ && is_secure_) {
// Call the interface to transmit the row set collection and gidx to the secure worker
// Call the interface to transmit gidx information to the secure worker
// for encrypted histogram compuation
processor_instance->InitAggregationContext(gidx);
auto slots = std::vector<int>();
auto num_rows = row_set_collection[0].Size();
auto cuts = gidx.Cuts().Ptrs();
for (int row = 0; row < num_rows; row++) {
for (int f = 0; f < cuts.size()-1; f++) {
auto slot = gidx.GetGindex(row, f);
slots.push_back(slot);
}
}
processor_instance->InitAggregationContext(cuts,slots);
// Further use the row set collection info to
// get the encrypted histogram from the secure worker
hist_data = processor_instance->ProcessAggregation(nodes_to_build, row_set_collection);
auto node_map = std::map<int, std::vector<int>>();
for (auto node : nodes_to_build) {
auto rows = std::vector<int>();
auto elem = row_set_collection[node];
for (auto it = elem.begin; it != elem.end; ++it) {
auto row_id = *it;
rows.push_back(row_id);
}
node_map.insert({node, rows});
}
size_t buf_size;
auto buf = processor_instance->ProcessAggregation(buf_size, node_map);
hist_data = xgboost::common::Span<std::int8_t>(static_cast<std::int8_t *>(buf), buf_size);
} else {
// Parallel processing by nodes and data in each node
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
Expand Down Expand Up @@ -218,8 +240,7 @@ class HistogramBuilder {
hist_data.data() + hist_data.size());
auto hist_entries = collective::Allgather(hist_vec);
// Call interface here to post-process the messages
auto hist_span = common::Span<std::int8_t>(hist_entries.data(), hist_entries.size());
std::vector<double> hist_aggr = processor_instance->HandleAggregation(hist_span);
std::vector<double> hist_aggr = processor_instance->HandleAggregation(hist_entries.data(), hist_entries.size());

// Update histogram for label owner
if (collective::GetRank() == 0) {
Expand Down

0 comments on commit 3a1f9ac

Please sign in to comment.