Skip to content

Commit

Permalink
revise and simplify template code
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Mar 6, 2024
1 parent d91be10 commit dd60317
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 78 deletions.
38 changes: 38 additions & 0 deletions CMakeCache.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# This is the CMakeCache file.
# For build in directory: /media/ziyuexu/Research/Experiment/SecureFedXGBoost/XGBoost/xgboost_SecureBoostP2
# It was generated by CMake: /usr/bin/cmake
# You can edit this file to change values found and used by cmake.
# If you do not want to change any of the values, simply exit the editor.
# If you do want to change a value, simply edit, save, and exit the editor.
# The syntax for the file is as follows:
# KEY:TYPE=VALUE
# KEY is the name of a variable in the cache.
# TYPE is a hint to GUIs for the type of VALUE, DO NOT EDIT TYPE!.
# VALUE is the current value for the KEY.

########################
# EXTERNAL cache entries
########################


########################
# INTERNAL cache entries
########################

//This is the directory where this CMakeCache.txt was created
CMAKE_CACHEFILE_DIR:INTERNAL=/media/ziyuexu/Research/Experiment/SecureFedXGBoost/XGBoost/xgboost_SecureBoostP2
//Major version of cmake used to create the current loaded cache
CMAKE_CACHE_MAJOR_VERSION:INTERNAL=3
//Minor version of cmake used to create the current loaded cache
CMAKE_CACHE_MINOR_VERSION:INTERNAL=22
//Patch version of cmake used to create the current loaded cache
CMAKE_CACHE_PATCH_VERSION:INTERNAL=1
//Path to CMake executable.
CMAKE_COMMAND:INTERNAL=/usr/bin/cmake
//Path to cpack program executable.
CMAKE_CPACK_COMMAND:INTERNAL=/usr/bin/cpack
//Path to ctest program executable.
CMAKE_CTEST_COMMAND:INTERNAL=/usr/bin/ctest
//Path to CMake installation.
CMAKE_ROOT:INTERNAL=/usr/share/cmake-3.22

4 changes: 4 additions & 0 deletions CMakeFiles/clion-log.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Cannot generate into /media/ziyuexu/Research/Experiment/SecureFedXGBoost/XGBoost/xgboost_SecureBoostP2
It is already used for unknown project

Please either delete it manually or select another generation directory
1 change: 1 addition & 0 deletions CMakeFiles/cmake.check_cache
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This file is generated by cmake for dependency checking of the CMakeCache.txt file
100 changes: 31 additions & 69 deletions src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
template <bool is_gpair, typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
Expand All @@ -91,84 +91,46 @@ void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function

std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);

result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
std::forward<Function>(function)();
}
}

// Same as above, but with encyption on the result
template <typename T, typename Function>
void ApplyWithLabelsEncrypted(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}

collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}

std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);

// save to vector and encrypt
if (collective::GetRank() == 0) {
// check the max and min value of the result vector
float max_g = std::numeric_limits<float>::min();
float min_g = std::numeric_limits<float>::max();
float max_h = std::numeric_limits<float>::min();
float min_h = std::numeric_limits<float>::max();
std::vector<double> result_vector_g, result_vector_h;
if (info.IsSecure() && is_gpair) {
// Under secure mode, gpairs will be processed to vector and encrypt
// information only available on rank 0
if (collective::GetRank() == 0) {
std::vector<double> vector_g, vector_h;
for (int i = 0; i < size; i++) {
result_vector_g.push_back(result->HostVector()[i].GetGrad());
result_vector_h.push_back(result->HostVector()[i].GetHess());

if (result->HostVector()[i].GetGrad() > max_g) {
max_g = result->HostVector()[i].GetGrad();
}
if (result->HostVector()[i].GetGrad() < min_g) {
min_g = result->HostVector()[i].GetGrad();
}
if (result->HostVector()[i].GetHess() > max_h) {
max_h = result->HostVector()[i].GetHess();
}
if (result->HostVector()[i].GetHess() < min_h) {
min_h = result->HostVector()[i].GetHess();
}
auto gpair = result->HostVector()[i];
// cast from GradientPair to float pointer
auto gpair_ptr = reinterpret_cast<float*>(&gpair);
// save to vector
vector_g.push_back(gpair_ptr[0]);
vector_h.push_back(gpair_ptr[1]);
}
// print 1 sample
//std::cout << " g[0]: " << result_vector_g[0] << " h[0]: " << result_vector_h[0] << std::endl;
// print max and min
//std::cout << "max_g: " << max_g << " min_g: " << min_g << " max_h: " << max_h << " min_h: " << min_h << std::endl;
// provide the vectors to the processor interface

}
// broadcast the encrypted data
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
// clear text mode
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
}

result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);

/*
// print 1 sample
std::cout << "Rank: " << collective::GetRank() << " after broadcast - g: " << result->HostVector()[0].GetGrad() << " h: " << result->HostVector()[0].GetHess() << std::endl;


if (is_gpair) {
std::cout << "Rank: " << collective::GetRank() << " after broadcast - g: "
<< reinterpret_cast<float*>(&result->HostVector()[0])[0] << " h: "
<< reinterpret_cast<float*>(&result->HostVector()[0])[1] << std::endl;
}
*/
} else {
std::forward<Function>(function)();
std::forward<Function>(function)();
}
}

Expand Down
12 changes: 3 additions & 9 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data(),
collective::ApplyWithLabels<false>(info, base_score->Data(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
Expand Down Expand Up @@ -1472,15 +1472,9 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
// calculate gradient and communicate with or without encryption
if (info.IsSecure()) {
collective::ApplyWithLabelsEncrypted(info, out_gpair->Data(),
// calculate gradient and communicate
collective::ApplyWithLabels<true>(info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
} else {
collective::ApplyWithLabels(info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
}

}

/*! \brief random number transformation seed. */
Expand Down

0 comments on commit dd60317

Please sign in to comment.