Skip to content

Commit

Permalink
fix allgather logic and update unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Mar 19, 2024
1 parent dd6adde commit 9567e67
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
2 changes: 2 additions & 0 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
collective::Broadcast(&size, sizeof(std::size_t), 0);

if (info.IsSecure() && is_gpair) {

// Under secure mode, gpairs will be processed into a vector and encrypted;
// this information is only available on rank 0
if (collective::GetRank() == 0) {
Expand All @@ -120,6 +121,7 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
// make broadcast call on the prepared data buffer
// (to local gRPC handler for further encryption)
// collective::Broadcast(gh_buffer, size_of_buffer, 0);

result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
Expand Down
53 changes: 31 additions & 22 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ class HistogramBuilder {
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {

// Print out all kinds of information for interface integration
if ((collective::GetRank() == 0)) {
if (is_distributed_ && is_col_split_ && is_secure_ && (collective::GetRank() == 0)) {
std::cout << "--------------Node Hist----------------" << std::endl;
std::cout << "Current samples on nodes: " << std::endl;
// print info on all nodes
Expand Down Expand Up @@ -111,10 +112,10 @@ class HistogramBuilder {
}
std::cout << std::endl;
std::cout << "------------------------------" << std::endl;
}
// Call the interface to transmit the row set collection and gidx to the secure worker
if ((collective::GetRank() == 0)) {
std::cout << "---------------CALL interface to transmit row & gidx------------" << std::endl;
// Call the interface to transmit the row set collection and gidx to the secure worker
if ((collective::GetRank() == 0)) {
std::cout << "---------------CALL interface to transmit row & gidx------------" << std::endl;
}
}

// Parallel processing by nodes and data in each node
Expand Down Expand Up @@ -238,19 +239,22 @@ class HistogramBuilder {
// Under secure vertical mode, we perform allgather to get the global histogram.
// note that only Label Owner needs the global histogram
CHECK(!nodes_to_build.empty());

// Front item of nodes_to_build
auto first_nidx = nodes_to_build.front();
// *2 because we have a pair of g and h for each histogram item
std::size_t n = n_total_bins * nodes_to_build.size() * 2;

// Use AllGather to collect the histogram entries from all nodes
// allocate memory for the received entries as a flat vector
std::vector<double> hist_flat;
hist_flat.resize(n);
// iterate through the nodes_to_build
auto it = reinterpret_cast<double *>(this->hist_[nodes_to_build.front()].data());
auto hist_size = this->hist_[nodes_to_build.front()].size();
auto it = reinterpret_cast<double *>(this->hist_[first_nidx].data());
auto hist_size = this->hist_[first_nidx].size();
for (size_t i = 0; i < n; i++) {
// get item with iterator
auto item = *it;
hist_flat[i] = item;
double item = *it;
hist_flat.push_back(item);
it++;
}

Expand All @@ -260,22 +264,27 @@ class HistogramBuilder {
if (collective::GetRank() == 0) {
std::cout << "---------------CALL Interface for processing-------------- " << std::endl;
}
// Update histogram for data owner

// Update histogram for label owner
if (collective::GetRank() == 0) {
// skip rank 0, as local hist already contains its own entries
// reposition iterator to the beginning of the vector
it = reinterpret_cast<double *>(this->hist_[nodes_to_build.front()].data());
for (auto rank_idx = 1; rank_idx < hist_entries.size()/n; rank_idx++) {
// iterate through the flat vector
for (size_t i = 0; i < n; i++) {
auto flat_idx = rank_idx * n + i;
auto hist_item = hist_entries[flat_idx];
// update the global histogram with the received entries
*it += hist_item;
it++;
// iterator to the beginning of the vector
auto it = reinterpret_cast<double *>(this->hist_[first_nidx].data());
// iterate through the hist vector of the label owner
for (size_t i = 0; i < n; i++) {
// skip rank 0, as local hist already contains its own entries
// get the sum of the entries from other ranks
double hist_sum = 0.0;
for (int rank_idx = 1; rank_idx < hist_entries.size()/n; rank_idx++) {
int flat_idx = rank_idx * n + i;
hist_sum += hist_entries.at(flat_idx);
}
// add other parties' sum to rank 0's record
// to get the global histogram
*it += hist_sum;
it++;
}
}

}

common::BlockedSpace2d const &subspace =
Expand Down
14 changes: 12 additions & 2 deletions tests/cpp/tree/hist/test_histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,18 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
collective::Allreduce<collective::Operation::kSum>(&grad, 1);
collective::Allreduce<collective::Operation::kSum>(&hess, 1);
}
ASSERT_NEAR(grad, histogram.Histogram()[nid][i].GetGrad(), kEps);
ASSERT_NEAR(hess, histogram.Histogram()[nid][i].GetHess(), kEps);
if (is_distributed && !is_col_split) {
// row split, all parties hold the same data
ASSERT_NEAR(grad, histogram.Histogram()[nid][i].GetGrad(), kEps);
ASSERT_NEAR(hess, histogram.Histogram()[nid][i].GetHess(), kEps);
}
if (is_distributed && is_col_split && is_secure) {
// secure col split, only rank 0 holds the global histogram
if (collective::GetRank() == 0) {
ASSERT_NEAR(grad, histogram.Histogram()[nid][i].GetGrad(), kEps);
ASSERT_NEAR(hess, histogram.Histogram()[nid][i].GetHess(), kEps);
}
}
}
}

Expand Down

0 comments on commit 9567e67

Please sign in to comment.