Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Mar 15, 2024
1 parent 8405791 commit db7d518
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
6 changes: 4 additions & 2 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>*
// provide the vectors to the processor interface
// print vector size for rank 1
if (collective::GetRank() == 0) {
std::cout << "DATA size of gpairs: " << vector_gh.size() << std::endl;
std::cout << "-----------Call Interface for gp encryption and broadcast"
<< ", size of gpairs: " << vector_gh.size()
<< " ----------------------" << std::endl;
}
}
// make broadcast call on the prepared data buffer
// (to local gRPC handler for further encryption)
//collective::Broadcast(gh_buffer, size_of_buffer, 0);
// collective::Broadcast(gh_buffer, size_of_buffer, 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
Expand Down
4 changes: 3 additions & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,9 @@ class LearnerConfiguration : public Learner {

this->ConfigureMetrics(args);

std::cout<<"configure interface here???????????????"<<std::endl;
if ((collective::GetRank() == 0)) {
std::cout << "configure interface here???????????????" << std::endl;
}

this->need_configuration_ = false;
if (ctx_.validate_parameters) {
Expand Down
50 changes: 29 additions & 21 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,37 +76,46 @@ class HistogramBuilder {
std::vector<bst_node_t> const &nodes_to_build,
common::RowSetCollection const &row_set_collection,
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
//Print out all kinds if information for interface integration
// Print out all kinds if information for interface integration
if ((collective::GetRank() == 0)) {
std::cout << "--------------Node Hist----------------" << std::endl;
std::cout << "Current samples on nodes: " << std::endl;
// print info on all nodes
for (bst_node_t nit = 0; nit < row_set_collection.Size(); ++nit) {
auto size = row_set_collection[nit].Size();
std::cout << "Node " << nit << " has " << size << " rows." << std::endl;
// print the first and last indexes of the rows with iterator
if (size > 0) {
std::cout << "First index for node " << nit << " is " << *row_set_collection[nit].begin << " and last index is " << *(row_set_collection[nit].end-1) << std::endl;
std::cout << "First index for node " << nit << " is "
<< *row_set_collection[nit].begin << " and last index is "
<< *(row_set_collection[nit].end - 1) << std::endl;
}
}
std::cout << std::endl;
// print info on the nodes to build
for (auto nit = nodes_to_build.begin(); nit != nodes_to_build.end(); ++nit) {
std::cout << "Building local histogram for node ID: " << *nit << " with " << row_set_collection[*nit].Size() << " samples." << std::endl;
std::cout << "Building local histogram for node ID: " << *nit
<< " with " << row_set_collection[*nit].Size()
<< " samples." << std::endl;
}
std::cout << std::endl;
std::cout << "Call interface to transmit the row set collection and gidx to the secure worker." << std::endl;
std::cout << "GHistIndexMatrix will not change: size of the ginidx: " << gidx.index.Size() << std::endl;
std::cout << "GHistIndexMatrix will not change with size " << gidx.index.Size() << std::endl;
auto cut_ptrs = gidx.Cuts().Ptrs();
//auto cut_values = gidx.Cuts().Values();
//std::cout << "size of the cut points: " << cut_ptrs.size() << std::endl;
std::cout << "first sample falls to: [feature_id, slot #]: " << std::endl;
auto cut_values = gidx.Cuts().Values();
// cut points: feature 0 start (0), feature 1 start, feature 2 start, ... feature n start
// cut value: cut for feature 0 slot 0, ..., cut for feature 0 slot m, cut for feature 1 slot 0, ...
std::cout << "size of the cut points and cut values: "
<< cut_ptrs.size() << " " << cut_values.size() << std::endl;
std::cout << "first sample falls to: [feature_id, slot #, slot cutValue]: " << std::endl;
for (auto i = 0; i < cut_ptrs.size()-1; ++i) {
//std::cout << "feature " << i << " first cut at " << cut_ptrs[i] + 1 << " with value " << cut_values[cut_ptrs[i]+1] << "; ";
std::cout << "[" << gidx.GetGindex(0, i) << ", " << i << "] ";
auto slot_number = gidx.GetGindex(0, i);
std::cout << "[" << i << ", " << slot_number << ", "<< cut_values[slot_number] << "] ";
}
std::cout << std::endl;
std::cout << "------------------------------" << std::endl;
}
// Call the interface to transmit the row set collection and gidx to the secure worker
if ((collective::GetRank() == 0)) {
std::cout << "---------------CALL interface to transmit the row and gidx------------" << std::endl;
}

// Parallel processing by nodes and data in each node
common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) {
Expand Down Expand Up @@ -233,11 +242,11 @@ class HistogramBuilder {

// Option 1: in theory the operation is AllGather, but with current system functionality,
// we use AllReduce to simulate the AllGather operation
auto first_nidx = nodes_to_build.front();
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
//auto first_nidx = nodes_to_build.front();
//collective::Allreduce<collective::Operation::kSum>(
// reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);


/*
// Option 2: use AllGather instead of AllReduce
// Collect the histogram entries from all nodes
// allocate memory for the received entries as a flat vector
Expand All @@ -263,7 +272,10 @@ class HistogramBuilder {

// Perform AllGather
auto hist_entries = collective::Allgather(hist_flat);
// Call interface here to post-process the messages
if (collective::GetRank() == 0) {
std::cout << "---------------CALL Interface for post processing-------------- " << std::endl;
}
// Update histogram for data owner
if (collective::GetRank() == 0) {
// skip rank 0, as local hist already contains its own entries
Expand All @@ -274,17 +286,13 @@ class HistogramBuilder {
// iterate through the flat vector
for (size_t i = 0; i < n; i++) {
auto flat_idx = rank_idx * n + i;
// DECRYPT the received entries HERE!!!!!!!!!
auto hist_item = hist_entries[flat_idx];
// update the global histogram with the received entries
*it += hist_item;
it++;
}
}
}
*/


}

common::BlockedSpace2d const &subspace =
Expand Down

0 comments on commit db7d518

Please sign in to comment.