Commit 26aaded: secure vertical GPU fully functional

ZiyueXu77 committed Jul 31, 2024
1 parent e42faaa commit 26aaded

Showing 3 changed files with 65 additions and 91 deletions.
27 changes: 0 additions & 27 deletions plugin/federated/federated_plugin.cc
@@ -21,20 +21,6 @@
namespace xgboost::collective {
void FederatedPluginMock::Reset(common::Span<std::uint32_t const> cutptrs,
common::Span<std::int32_t const> bin_idx) {
-
-  //print some contents of cutptrs and bin_idx
-  std::cout << "cutptrs.size() = " << cutptrs.size() << std::endl;
-  for (int i = 0; i < cutptrs.size(); i++) {
-    std::cout << cutptrs[i] << " ";
-  }
-  std::cout << std::endl;
-
-  std::cout << "bin_idx.size() = " << bin_idx.size() << std::endl;
-  for (int i = 0; i < 3; i++) {
-    std::cout << bin_idx[i] << " ";
-  }
-  std::cout << std::endl;
-
this->cuts_.resize(cutptrs.size());
std::copy_n(cutptrs.data(), cutptrs.size(), this->cuts_.data());

@@ -84,19 +70,6 @@ void FederatedPluginMock::Reset(common::Span<std::uint32_t const> cutptrs,
auto hist_buffer = common::Span<double>{hist_plain_};
std::fill_n(hist_buffer.data(), hist_buffer.size(), 0.0);

-  // print some contents of rowptrs
-
-  std::cout << "rowptrs.size() = " << rowptrs.size() << std::endl;
-  for (int i = 0; i < rowptrs.size(); i++) {
-    std::cout << sizes[i] << std::endl;
-    common::Span row_indices{rowptrs[i], rowptrs[i] + sizes[i]};
-    for (int j = 0; j < 5; j++) {
-      std::cout << row_indices[j] << " ";
-    }
-    std::cout << std::endl;
-  }
-  std::cout << std::endl;
-
CHECK_EQ(rowptrs.size(), sizes.size());
CHECK_EQ(nids.size(), sizes.size());
auto gpair = common::RestoreType<GradientPair const>(common::Span<std::uint8_t>{grad_});
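Note: common::RestoreType above reinterprets the plugin's raw byte buffer as a typed gradient span without copying. A minimal standalone sketch of the idea (a simplified stand-in, not the actual xgboost implementation):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

struct GradientPair {
  float grad;
  float hess;
};

// Reinterpret a byte buffer as n_bytes / sizeof(T) elements of T.
template <typename T>
T const* RestoreTypeSketch(std::uint8_t const* bytes, std::size_t n_bytes, std::size_t* n_out) {
  *n_out = n_bytes / sizeof(T);
  return reinterpret_cast<T const*>(bytes);
}

int main() {
  // Serialize two gradient pairs into bytes, as the plugin would receive them.
  std::vector<GradientPair> src{{0.5f, 1.0f}, {-0.25f, 2.0f}};
  std::vector<std::uint8_t> buf(src.size() * sizeof(GradientPair));
  std::memcpy(buf.data(), src.data(), buf.size());

  std::size_t n = 0;
  auto const* gpair = RestoreTypeSketch<GradientPair>(buf.data(), buf.size(), &n);
  for (std::size_t i = 0; i < n; ++i) {
    std::cout << gpair[i].grad << " " << gpair[i].hess << "\n";
  }
}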
11 changes: 7 additions & 4 deletions src/collective/aggregator.h
@@ -227,13 +227,16 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
SafeColl(rc);
// Pass the gradient to the plugin
fed.EncryptionPlugin()->SyncEncryptedGradient(encrypted);
+
+  // !!! Temporary solution
+  // This step is needed for memory allocation in the case of vertical secure GPU:
+  // fill out_gpair with all zeros to avoid leaking label information
+  auto gpair_data = out_gpair->Data();
+  gpair_data->Fill(GradientPair{0.0f, 0.0f});
+  ApplyWithLabels(ctx, info, gpair_data, [&] { grad_fn(out_gpair); });
#else
LOG(FATAL) << error::NoFederated();
#endif
-
-  // !!!Temporarily turn on regular gradient broadcasting for testing
-  // encrypted vertical
-  ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
} else {
ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
}
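The added branch allocates out_gpair on every rank, zero-fills it to avoid leaking label information, and then lets only the label owner overwrite it with real gradients. A minimal sketch of that pattern, assuming rank 0 owns the labels (names here are illustrative, not the xgboost API):

#include <cstddef>
#include <vector>

struct GradientPair {
  float grad{0.0f};
  float hess{0.0f};
};

// Hypothetical stand-in for the label-owner query.
bool IsLabelOwner(int rank) { return rank == 0; }

std::vector<GradientPair> BroadcastGradientSketch(int rank, std::size_t n_rows) {
  // Every rank performs the allocation (downstream GPU code expects the
  // buffer to exist), but the contents stay zero on non-label owners.
  std::vector<GradientPair> out(n_rows, GradientPair{});
  if (IsLabelOwner(rank)) {
    for (std::size_t i = 0; i < n_rows; ++i) {
      out[i] = GradientPair{0.1f, 1.0f};  // placeholder for the real grad_fn
    }
  }
  return out;
}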
118 changes: 58 additions & 60 deletions src/tree/gpu_hist/histogram.cu
@@ -14,6 +14,8 @@
#include "row_partitioner.cuh"
#include "xgboost/base.h"

#include "../../collective/allgather.h" // for AllgatherV

#include "../../common/device_helpers.cuh"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../../plugin/federated/federated_hist.h" // for FederataedHistPolicy
@@ -361,6 +363,24 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor con
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
}

+struct ReadMatrixFunction {
+  EllpackDeviceAccessor matrix;
+  int kCols;
+  bst_float* matrix_data_d;
+  ReadMatrixFunction(EllpackDeviceAccessor matrix, int kCols, bst_float* matrix_data_d)
+      : matrix(std::move(matrix)), kCols(kCols), matrix_data_d(matrix_data_d) {}
+
+  __device__ void operator()(size_t global_idx) {
+    auto row = global_idx / kCols;
+    auto col = global_idx % kCols;
+    auto value = matrix.GetBinIndex(row, col);
+    if (isnan(value)) {
+      value = -1;
+    }
+    matrix_data_d[global_idx] = value;
+  }
+};
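ReadMatrixFunction flattens the ELLPACK accessor into a dense row-major bin matrix with NaN mapped to -1 as the missing-value sentinel. A rough plain-CUDA sketch of the same operation, with the accessor replaced by a raw input array for illustration:

#include <cmath>
#include <cstddef>

// One thread per cell; the flat index decomposes into (row, col) exactly as
// in the functor above, and NaN becomes the -1 "missing" sentinel.
__global__ void ReadMatrixKernel(float const* bins, std::size_t n_rows, std::size_t n_cols,
                                 float* out) {
  std::size_t idx = static_cast<std::size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx >= n_rows * n_cols) return;
  std::size_t row = idx / n_cols;
  std::size_t col = idx % n_cols;
  float value = bins[row * n_cols + col];  // stands in for matrix.GetBinIndex(row, col)
  out[idx] = isnan(value) ? -1.0f : value;
}

In the diff itself, dh::LaunchN plays the role of the launch configuration, invoking the functor once per kRows * kCols cells.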

void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
@@ -378,105 +398,83 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Transmit matrix to plugin
if (!is_aggr_context_initialized_) {
std::cout << "Initialized Plugin Context" << std::endl;
// Get cutptrs
std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
common::Span<std::uint32_t const> cutptrs = common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());
std::cout << "cutptrs.size() = " << h_cuts_ptr.size() << std::endl;
for (int i = 0; i < h_cuts_ptr.size(); i++) {
std::cout << h_cuts_ptr[i] << " ";
}
std::cout << std::endl;

// Get bin_idx matrix
-
-
-
-  //common::Span<std::int32_t const> bin_idx
-  //*********************************************
-  //plugin->Reset(h_cuts_ptr, bin_idx);
-  //*********************************************
+  auto kRows = matrix.n_rows;
+  auto kCols = matrix.NumFeatures();
+  std::vector<int32_t> h_bin_idx(kRows * kCols);
+  // Access GPU matrix data
+  thrust::device_vector<bst_float> matrix_d(kRows * kCols);
+  dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
+  thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
+  common::Span<std::int32_t const> bin_idx = common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

+  // Initialize plugin context
+  plugin->Reset(h_cuts_ptr, h_bin_idx);
is_aggr_context_initialized_ = true;
}
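The h_cuts_ptr vector handed to plugin->Reset follows the usual quantile-cut layout: n_features + 1 monotone offsets, where feature f owns the global bin range [cutptrs[f], cutptrs[f+1]). A toy host-side illustration of that convention (values are made up):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 3 features with 4, 5, and 3 bins respectively; 12 bins in total.
  std::vector<std::uint32_t> cutptrs{0, 4, 9, 12};
  for (std::size_t f = 0; f + 1 < cutptrs.size(); ++f) {
    std::cout << "feature " << f << ": bins [" << cutptrs[f] << ", "
              << cutptrs[f + 1] << ") -> " << (cutptrs[f + 1] - cutptrs[f])
              << " bins\n";
  }
}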


// get row indices from device
std::vector<uint32_t> h_ridx(ridx.size());
dh::CopyDeviceSpanToVector(&h_ridx, ridx);

-  // wrap info following plugin expectations
+  // necessary conversions to fit the plugin's expected types
+  std::vector<uint64_t> h_ridx_64(ridx.size());
+  for (int i = 0; i < ridx.size(); i++) {
+    h_ridx_64[i] = h_ridx[i];
+  }
std::vector<std::uint64_t const *> ptrs(1);
std::vector<std::size_t> sizes(1);
std::vector<bst_node_t> nodes(1);
-  ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx.data());
-  sizes[0] = h_ridx.size();
+  ptrs[0] = reinterpret_cast<std::uint64_t const *>(h_ridx_64.data());
+  sizes[0] = h_ridx_64.size();
nodes[0] = 0;
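Here ptrs/sizes/nodes describe a single span of row indices for the root node (node 0). With a full tree the same parallel arrays would carry one span per node; a sketch of that generalization (interface shape inferred from the call below, not the actual plugin signature):

#include <cstddef>
#include <cstdint>
#include <vector>

using bst_node_t = std::int32_t;

// Pack one row-index span per tree node into the parallel arrays that a
// BuildEncryptedHistVert-style call expects.
void PackNodeSpans(std::vector<std::vector<std::uint64_t>> const& rows_per_node,
                   std::vector<std::uint64_t const*>* ptrs,
                   std::vector<std::size_t>* sizes,
                   std::vector<bst_node_t>* nodes) {
  for (std::size_t nid = 0; nid < rows_per_node.size(); ++nid) {
    ptrs->push_back(rows_per_node[nid].data());
    sizes->push_back(rows_per_node[nid].size());
    nodes->push_back(static_cast<bst_node_t>(nid));
  }
}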

// Transmit row indices to plugin and get encrypted histogram
std::cout << "Building encrypted histograms with row indices " << std::endl;
//*********************************************
//auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);
//*********************************************
auto hist_data = plugin->BuildEncryptedHistVert(ptrs, sizes, nodes);

// Perform AllGather
std::cout << "Allgather histograms" << std::endl;

//*********************************************
/*
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
-      collective::AllgatherV(ctx_, linalg::MakeVec(hist_data_), &recv_segments, &hist_entries));
+      collective::AllgatherV(ctx, linalg::MakeVec(hist_data), &recv_segments, &hist_entries));

// Call the plugin here to get the resulting histogram. Histogram from all workers are
// gathered to the label owner.
common::Span<double> hist_aggr =
plugin_->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
*/
//*********************************************





plugin->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

-  // !!!Temporarily turn on regular histogram building for testing
+  // compute local histograms
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram, rounding);

-  // Further histogram sync process - simulated with allreduce
-  // copy histogram data to host
+  // Post-process the AllGathered data
+  auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
dh::CopyDeviceSpanToVector(&host_histogram, histogram);
// convert to regular vector
std::vector<std::int64_t> host_histogram_64(histogram.size() * 2);
for (auto i = 0; i < host_histogram.size(); i++) {
host_histogram_64[i * 2] = host_histogram[i].GetQuantisedGrad();
host_histogram_64[i * 2 + 1] = host_histogram[i].GetQuantisedHess();
}
// aggregate histograms in float
auto rc = collective::Allreduce(ctx, &host_histogram_64, collective::Op::kSum);
SafeColl(rc);
// convert back to GradientPairInt64
// only copy to Rank 0, clear other ranks to simulate the plugin scenario
for (auto i = 0; i < host_histogram.size(); i++) {
GradientPairInt64 hist_item(host_histogram_64[i * 2], host_histogram_64[i * 2 + 1]);
GradientPairInt64 hist_item_empty(0, 0);
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
}
GradientPairPrecise hist_item(grad, hess);
GradientPairPrecise hist_item_empty(0.0, 0.0);
if (collective::GetRank() != 0) {
-    hist_item = hist_item_empty;
+    host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
} else {
-    host_histogram[i] = hist_item;
+    host_histogram[i] = rounding.ToFixedPoint(hist_item);
}
}
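The idx = rank * histogram.size() + i arithmetic above assumes AllgatherV concatenates each rank's flattened histogram of interleaved (grad, hess) doubles back-to-back. A host-only toy example of that layout and the summation:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::size_t const n_bins = 2, world_size = 3;
  // hist_aggr[(rank * n_bins + i) * 2 + {0,1}] holds rank's (grad, hess) for bin i.
  std::vector<double> hist_aggr{1, 10, 2, 20,    // rank 0
                                3, 30, 4, 40,    // rank 1
                                5, 50, 6, 60};   // rank 2
  for (std::size_t i = 0; i < n_bins; ++i) {
    double grad = 0, hess = 0;
    for (std::size_t rank = 0; rank < world_size; ++rank) {
      std::size_t idx = rank * n_bins + i;
      grad += hist_aggr[idx * 2];
      hess += hist_aggr[idx * 2 + 1];
    }
    std::cout << "bin " << i << ": grad=" << grad << " hess=" << hess << "\n";
  }
}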

// copy the aggregated histogram back to GPU memory
// at this point, the histogram contains full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
-                                histogram.size() * sizeof(GradientPairPrecise),
+                                histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
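rounding.ToFixedPoint quantizes the double-precision (grad, hess) sums into the GradientPairInt64 fixed-point form the GPU kernels accumulate in. A simplified scalar sketch of such a quantizer; the real GradientQuantiser derives its scale from the observed gradient range, so the constant here is purely illustrative:

#include <cmath>
#include <cstdint>
#include <iostream>

// Scale and round to int64 so that sums across workers are exactly
// reproducible; divide by the same scale to recover a floating-point value.
std::int64_t ToFixedPoint(double v, double scale) {
  return static_cast<std::int64_t>(std::llround(v * scale));
}

double ToFloatingPoint(std::int64_t v, double scale) {
  return static_cast<double>(v) / scale;
}

int main() {
  double const scale = 1 << 16;  // illustrative; xgboost derives this from data
  std::int64_t q = ToFixedPoint(0.123456, scale);
  std::cout << q << " -> " << ToFloatingPoint(q, scale) << "\n";
}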

}
}
} // namespace xgboost::tree
