Skip to content

Commit

Permalink
Added support for horizontal secure XGBoost
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Apr 25, 2024
1 parent 2a8f19a commit 9ff2935
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <iostream>
#include <cstring>
#include <cstdint>
#include "./dummy_processor.h"
#include "./mock_processor.h"

const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1
const int64_t kPrefixLen = 24;
Expand All @@ -13,7 +13,7 @@ bool ValidDam(void *buffer, std::size_t size) {
return size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0;
}

void* DummyProcessor::ProcessGHPairs(std::size_t *size, const std::vector<double>& pairs) {
void* MockProcessor::ProcessGHPairs(std::size_t *size, const std::vector<double>& pairs) {
*size = kPrefixLen + pairs.size()*10*8; // Assume encrypted size is 10x

int64_t buf_size = *size;
Expand All @@ -39,13 +39,13 @@ void* DummyProcessor::ProcessGHPairs(std::size_t *size, const std::vector<double
}


void* DummyProcessor::HandleGHPairs(std::size_t *size, void *buffer, std::size_t buf_size) {
void* MockProcessor::HandleGHPairs(std::size_t *size, void *buffer, std::size_t buf_size) {
*size = buf_size;
if (!ValidDam(buffer, *size)) {
return buffer;
}

// For dummy, this call is used to set gh_pairs for passive sites
// For mock, this call is used to set gh_pairs for passive sites
if (!active_) {
int8_t *ptr = static_cast<int8_t *>(buffer);
ptr += kPrefixLen;
Expand All @@ -60,7 +60,7 @@ void* DummyProcessor::HandleGHPairs(std::size_t *size, void *buffer, std::size_t
return buffer;
}

void *DummyProcessor::ProcessAggregation(std::size_t *size, std::map<int, std::vector<int>> nodes) {
void *MockProcessor::ProcessAggregation(std::size_t *size, std::map<int, std::vector<int>> nodes) {
int total_bin_size = cuts_.back();
int histo_size = total_bin_size*2;
*size = kPrefixLen + 8*histo_size*nodes.size();
Expand Down Expand Up @@ -93,15 +93,15 @@ void *DummyProcessor::ProcessAggregation(std::size_t *size, std::map<int, std::v
return buf;
}

std::vector<double> DummyProcessor::HandleAggregation(void *buffer, std::size_t buf_size) {
std::vector<double> MockProcessor::HandleAggregation(void *buffer, std::size_t buf_size) {
std::vector<double> result = std::vector<double>();

int8_t* ptr = static_cast<int8_t *>(buffer);
auto rest_size = buf_size;

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr, rest_size)) {
continue;
break;
}
int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
Expand All @@ -113,3 +113,62 @@ std::vector<double> DummyProcessor::HandleAggregation(void *buffer, std::size_t

return result;
}

void* MockProcessor::ProcessHistograms(std::size_t *size, const std::vector<double>& histograms) {
*size = kPrefixLen + histograms.size()*10*8; // Assume encrypted size is 10x

int64_t buf_size = *size;
// This memory needs to be freed
char *buf = static_cast<char *>(malloc(buf_size));
memcpy(buf, kSignature, strlen(kSignature));
memcpy(buf + 8, &buf_size, 8);
memcpy(buf + 16, &kDataTypeAggregatedHisto, 8);

// Simulate encryption by duplicating value 10 times
int index = kPrefixLen;
for (auto value : histograms) {
for (std::size_t i = 0; i < 10; i++) {
memcpy(buf+index, &value, 8);
index += 8;
}
}

return buf;
}

std::vector<double> MockProcessor::HandleHistograms(void *buffer, std::size_t buf_size) {
std::vector<double> result = std::vector<double>();

int8_t* ptr = static_cast<int8_t *>(buffer);
auto rest_size = buf_size;

while (rest_size > kPrefixLen) {
if (!ValidDam(ptr, rest_size)) {
break;
}
int64_t *size_ptr = reinterpret_cast<int64_t *>(ptr + 8);
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
auto array_size = (*size_ptr - kPrefixLen)/8;
auto empty = result.empty();
if (!empty) {
if (result.size() != array_size / 10) {
std::cout << "Histogram size doesn't match " << result.size() << " != " << array_size << std::endl;
return result;
}
}

for (std::size_t i = 0; i < array_size/10; i++) {
auto value = array_start[i*10];
if (empty) {
result.push_back(value);
} else {
result[i] += value;
}
}

rest_size -= *size_ptr;
ptr = ptr + *size_ptr;
}

return result;
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
// Data type definition
const int64_t kDataTypeGHPairs = 1;
const int64_t kDataTypeHisto = 2;
const int64_t kDataTypeAggregatedHisto = 3;

class DummyProcessor: public processing::Processor {
class MockProcessor: public processing::Processor {
private:
bool active_ = false;
const std::map<std::string, std::string> *params_{nullptr};
Expand Down Expand Up @@ -50,4 +51,8 @@ class DummyProcessor: public processing::Processor {
void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;

void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) override;

std::vector<double> HandleHistograms(void *buffer, size_t buf_size) override;
};
28 changes: 22 additions & 6 deletions src/processing/processor.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace processing {

const char kLibraryPath[] = "LIBRARY_PATH";
const char kDummyProcessor[] = "dummy";
const char kMockProcessor[] = "mock";
const char kLoadFunc[] = "LoadProcessor";

/*! \brief An processor interface to handle tasks that require external library through plugins */
Expand Down Expand Up @@ -76,7 +76,7 @@ class Processor {
* \param size The output buffer size
* \param nodes Map of node and the rows belong to this node
*
* \return The encoded buffer to be sent via AllGather
* \return The encoded buffer to be sent via AllGatherV
*/
virtual void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) = 0;

Expand All @@ -90,16 +90,32 @@ class Processor {
* site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3
*/
virtual std::vector<double> HandleAggregation(void *buffer, size_t buf_size) = 0;

/*!
* \brief Prepare histograms for further processing
*
* \param size The output buffer size
* \param histograms Flattened array of histograms for all features
*
* \return The encoded buffer to be sent via AllGatherV
*/
virtual void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) = 0;

/*!
* \brief Handle processed histograms
*
* \param buffer Buffer from allgatherV
* \param buf_size The size of the buffer
*
* \return A flattened vector of histograms for all features
*/
virtual std::vector<double> HandleHistograms(void *buffer, size_t buf_size) = 0;
};

class ProcessorLoader {
private:
std::map<std::string, std::string> params;
#if defined(_WIN32)
HMODULE handle_ = NULL;
#else
void *handle_ = NULL;
#endif

public:
ProcessorLoader(): params{} {}
Expand Down
18 changes: 8 additions & 10 deletions src/processing/processor_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,28 @@

#include <iostream>

#if defined(_WIN32)
#if defined(_WIN32) || defined(_WIN64)
#include <windows.h>
#else
#include <dlfcn.h>
#endif

#include "./processor.h"
#include "plugins/dummy_processor.h"
#include "plugins/mock_processor.h"

namespace processing {
using LoadFunc = Processor *(const char *);

Processor* ProcessorLoader::load(const std::string& plugin_name) {
// Dummy processor for unit testing without loading a shared library
if (plugin_name == kDummyProcessor) {
return new DummyProcessor();
if (plugin_name == kMockProcessor) {
return new MockProcessor();
}

auto lib_name = "libproc_" + plugin_name;

auto extension =
#if defined(_WIN32)
#if defined(_WIN32) || defined(_WIN64)
".dll";
#elif defined(__APPLE__) || defined(__MACH__)
".dylib";
Expand All @@ -46,19 +46,18 @@ namespace processing {
lib_path = p + lib_file_name;
}

#if defined(_WIN32)
HMODULE handle_ = LoadLibrary(lib_path.c_str());
#if defined(_WIN32) || defined(_WIN64)
handle_ = reinterpret_cast<void *>(LoadLibrary(lib_path.c_str()));
if (!handle_) {
std::cerr << "Failed to load the dynamic library" << std::endl;
return NULL;
}

void* func_ptr = GetProcAddress(handle_, kLoadFunc);
void* func_ptr = reinterpret_cast<void *>(GetProcAddress((HMODULE)handle_, kLoadFunc));
if (!func_ptr) {
std::cerr << "Failed to find loader function." << std::endl;
return NULL;
}

#else
handle_ = dlopen(lib_path.c_str(), RTLD_LAZY);
if (!handle_) {
Expand All @@ -70,7 +69,6 @@ namespace processing {
std::cerr << "Failed to find loader function: " << dlerror() << std::endl;
return NULL;
}

#endif

auto func = reinterpret_cast<LoadFunc *>(func_ptr);
Expand Down
Loading

0 comments on commit 9ff2935

Please sign in to comment.