Skip to content

Commit

Permalink
Move processor interface init from learner to communicator
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Apr 12, 2024
1 parent 3a1f9ac commit 1107604
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 9 deletions.
9 changes: 9 additions & 0 deletions src/collective/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#include "../processing/processor.h"
processing::Processor *processor_instance;
#endif

namespace xgboost::collective {
Expand Down Expand Up @@ -39,6 +41,13 @@ void Communicator::Init(Json const& config) {
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
communicator_.reset(FederatedCommunicator::Create(config));
std::cout << "!!!!!!!! Communicator Initialization!!!!!!!!!!!!!!!!!!!! " << std::endl;
auto plugin_name = "dummy";
std::map<std::string, std::string> loader_params = {{"LIBRARY_PATH", "/tmp"}};
std::map<std::string, std::string> proc_params = {};
processing::ProcessorLoader loader(loader_params);
processor_instance = loader.load(plugin_name);
processor_instance->Initialize(collective::GetRank() == 0, proc_params);
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
Expand Down
11 changes: 2 additions & 9 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@
#include "xgboost/predictor.h" // for PredictionContainer, PredictionCacheEntry
#include "xgboost/string_view.h" // for operator<<, StringView
#include "xgboost/task.h" // for ObjInfo
#include "processing/processor.h" // for Processor

namespace {
const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace

DECLARE_FIELD_ENUM_CLASS(xgboost::MultiStrategy);
processing::Processor *processor_instance;

namespace xgboost {
Learner::~Learner() = default;
namespace {
Expand Down Expand Up @@ -493,13 +493,6 @@ class LearnerConfiguration : public Learner {

this->ConfigureMetrics(args);

std::map<std::string, std::string> loader_params = {{"LIBRARY_PATH", "/tmp"}};
std::map<std::string, std::string> proc_params = {};
auto plugin_name = "dummy";
processing::ProcessorLoader loader(loader_params);
processor_instance = loader.load(plugin_name);
processor_instance->Initialize(collective::GetRank() == 0, proc_params);

this->need_configuration_ = false;
if (ctx_.validate_parameters) {
this->ValidateParameters();
Expand Down
139 changes: 139 additions & 0 deletions src/processing/plugins/dummy_processor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#include <iostream>
#include "./dummy_processor.h"

using std::vector;
using std::cout;
using std::endl;

// Serialized buffer header: 8-byte signature, 8-byte total buffer size,
// 8-byte data-type tag, followed by the payload.
const char kSignature[] = "NVDADAM1";  // DAM (Direct Accessible Marshalling) V1
const int64_t kPrefixLen = 24;

/*!
 * \brief Check that a serialized buffer starts with the DAM signature.
 *
 * \param buffer Buffer of at least 8 readable bytes (the signature length)
 *
 * \return true if the buffer begins with kSignature
 */
bool ValidDam(void *buffer) {
  // sizeof(kSignature) - 1 excludes the terminating NUL and is a compile-time
  // constant, unlike strlen() on a non-constexpr array.
  return memcmp(buffer, kSignature, sizeof(kSignature) - 1) == 0;
}

/*!
 * \brief Encode g/h pairs for broadcast, simulating encryption by writing each
 *        value 10 times after a DAM header.
 *
 * \param size  Output parameter receiving the size of the returned buffer
 * \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...) for every sample
 *
 * \return Heap buffer owned by the caller (release via FreeBuffer), or nullptr
 *         if allocation fails (size is set to 0 in that case)
 */
void* DummyProcessor::ProcessGHPairs(size_t &size, std::vector<double>& pairs) {
  cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;

  size = kPrefixLen + pairs.size()*10*8;  // Assume encrypted size is 10x

  int64_t buf_size = size;
  // This memory needs to be freed by the caller via FreeBuffer().
  char *buf = static_cast<char *>(calloc(size, 1));
  if (buf == nullptr) {
    // Previously an allocation failure crashed on the first memcpy below.
    size = 0;
    return nullptr;
  }
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &buf_size, 8);
  memcpy(buf + 16, &processing::kDataTypeGHPairs, 8);

  // Simulate encryption by duplicating value 10 times
  int index = kPrefixLen;
  for (auto value : pairs) {
    for (int i = 0; i < 10; i++) {
      memcpy(buf+index, &value, 8);
      index += 8;
    }
  }

  // Save pairs for future operations. Delete any copy kept by an earlier call
  // so repeated invocations do not leak.
  delete this->gh_pairs_;
  this->gh_pairs_ = new vector<double>(pairs);

  return buf;
}


/*!
 * \brief Receive broadcast g/h pairs; passive sites decode and store them.
 *
 * \param size     Output buffer size (left untouched; the input buffer is
 *                 returned as-is)
 * \param buffer   The encoded buffer produced by ProcessGHPairs
 * \param buf_size The encoded buffer size
 *
 * \return The input buffer, unchanged
 */
void* DummyProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) {
  cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl;

  if (!ValidDam(buffer)) {
    cout << "Invalid buffer received" << endl;
    return buffer;
  }

  // For dummy, this call is used to set gh_pairs for passive sites
  if (!active_) {
    int8_t *ptr = static_cast<int8_t *>(buffer);
    ptr += kPrefixLen;
    double *pairs = reinterpret_cast<double *>(ptr);
    size_t num = (buf_size - kPrefixLen) / 8;
    // Release any previously stored pairs to avoid leaking them.
    delete gh_pairs_;
    gh_pairs_ = new vector<double>();
    // ProcessGHPairs wrote each value 10 times; keep one copy of each.
    // (size_t index fixes the old signed/unsigned comparison with `num`.)
    for (size_t i = 0; i < num; i += 10) {
      gh_pairs_->push_back(pairs[i]);
    }
    cout << "GH Pairs saved. Size: " << gh_pairs_->size() << endl;
  }

  return buffer;
}

/*!
 * \brief Build per-node g/h histograms for the given row sets and encode them
 *        behind a DAM header.
 *
 * \param size  The output buffer size
 * \param nodes Map of node id to the rows belonging to that node
 *
 * \return Heap buffer owned by the caller (release via FreeBuffer), or nullptr
 *         if allocation fails (size is set to 0 in that case)
 */
void *DummyProcessor::ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) {
  auto total_bin_size = cuts_.back();
  auto histo_size = total_bin_size*2;  // two doubles (g, h) per bin
  size = kPrefixLen + 8*histo_size*nodes.size();
  int64_t buf_size = size;
  cout << "ProcessAggregation called with bin size: " << total_bin_size << " Buffer Size: " << buf_size << endl;
  std::int8_t *buf = static_cast<std::int8_t *>(calloc(buf_size, 1));
  if (buf == nullptr) {
    size = 0;
    return nullptr;
  }
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &buf_size, 8);
  memcpy(buf + 16, &processing::kDataTypeHisto, 8);

  double *histo = reinterpret_cast<double *>(buf + kPrefixLen);
  for ( const auto &node : nodes ) {
    const auto &rows = node.second;  // reference: the old copy of the row vector was wasted work
    for (const auto &row_id : rows) {
      auto num = cuts_.size() - 1;  // number of features
      for (std::size_t f = 0; f < num; f++) {
        auto slot = slots_[f + num*row_id];
        if (slot < 0) {
          continue;  // feature not assigned to any bin for this row
        }

        // slot >= 0 here, so the cast to the unsigned bin type is safe.
        if (static_cast<uint32_t>(slot) >= total_bin_size) {
          cout << "Slot too big, ignored: " << slot << endl;
          continue;
        }

        if (static_cast<std::size_t>(row_id) >= gh_pairs_->size()/2) {
          cout << "Row ID too big: " << row_id << endl;
          // Previously this only logged and then read past the end of gh_pairs_.
          continue;
        }

        auto g = (*gh_pairs_)[row_id*2];
        auto h = (*gh_pairs_)[row_id*2+1];
        histo[slot*2] += g;
        histo[slot*2+1] += h;
      }
    }
    histo += histo_size;  // next node's histogram region
  }

  return buf;
}

/*!
 * \brief Decode the concatenated all-gather result into a flat histogram vector.
 *
 * \param buffer   Buffer from all gather: a sequence of DAM-framed chunks
 * \param buf_size The size of the buffer
 *
 * \return A flattened vector of histograms for each site/node
 */
std::vector<double> DummyProcessor::HandleAggregation(void *buffer, size_t buf_size) {
  cout << "HandleAggregation called with buffer size: " << buf_size << endl;
  std::vector<double> result = std::vector<double>();

  int8_t* ptr = static_cast<int8_t *>(buffer);
  auto rest_size = buf_size;

  while (rest_size > static_cast<size_t>(kPrefixLen)) {
    if (!ValidDam(ptr)) {
      cout << "Invalid buffer at offset " << buf_size - rest_size << endl;
      // A malformed chunk has no trusted length header, so we cannot
      // resynchronize; the old `continue` here spun forever without
      // advancing ptr/rest_size.
      break;
    }
    std::int64_t *size_ptr = reinterpret_cast<std::int64_t *>(ptr + 8);
    // Guard against a corrupt length that would underflow rest_size or run
    // past the end of the buffer.
    if (*size_ptr < kPrefixLen || static_cast<size_t>(*size_ptr) > rest_size) {
      cout << "Invalid buffer at offset " << buf_size - rest_size << endl;
      break;
    }
    double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen);
    auto array_size = (*size_ptr - kPrefixLen)/8;
    cout << "Histo size for buffer: " << array_size << endl;
    result.insert(result.end(), array_start, array_start + array_size);
    cout << "Result size: " << result.size() << endl;
    rest_size -= *size_ptr;
    ptr = ptr + *size_ptr;
  }

  cout << "Total histo size: " << result.size() << endl;

  return result;
}
53 changes: 53 additions & 0 deletions src/processing/plugins/dummy_processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#pragma once
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "../processor.h"

/*! \brief Pass-through reference implementation of processing::Processor used
 *         for testing the plugin interface ("encryption" just duplicates data). */
class DummyProcessor: public processing::Processor {
 private:
  bool active_ = false;
  // Copy of the params given to Initialize(). Stored by value: the previous
  // implementation kept a pointer to the by-value argument, which dangled as
  // soon as Initialize() returned.
  std::map<std::string, std::string> params_;
  // Decoded g/h pairs saved by ProcessGHPairs/HandleGHPairs; owned by this object.
  std::vector<double> *gh_pairs_{nullptr};
  std::vector<uint32_t> cuts_;   // cumulative bin cuts per feature
  std::vector<int> slots_;       // flattened feature-by-row slot assignment

 public:
  void Initialize(bool active, std::map<std::string, std::string> params) override {
    this->active_ = active;
    this->params_ = std::move(params);
  }

  void Shutdown() override {
    // Release the saved pairs; previously the pointer was only nulled,
    // leaking the vector allocated in the .cc implementation.
    delete this->gh_pairs_;
    this->gh_pairs_ = nullptr;
    this->cuts_.clear();
    this->slots_.clear();
  }

  void FreeBuffer(void *buffer) override {
    // Buffers are produced with calloc in the .cc file, hence free().
    free(buffer);
  }

  void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) override;

  void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override;

  void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) override {
    std::cout << "InitAggregationContext called with cuts size: " << cuts.size()-1 <<
        " number of slot: " << slots.size() << std::endl;
    this->cuts_ = cuts;
    if (this->slots_.empty()) {
      this->slots_ = slots;
    } else {
      std::cout << "Multiple calls to InitAggregationContext" << std::endl;
    }
  }

  void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) override;

  std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;
};
114 changes: 114 additions & 0 deletions src/processing/processor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/**
* Copyright 2014-2024 by XGBoost Contributors
*/
#pragma once

#include <map>
#include <any>
#include <string>
#include <vector>

namespace processing {

// Well-known keys and names used by the plugin loader.
const char kLibraryPath[] = "LIBRARY_PATH";   // loader param: directory containing the plugin library
const char kDummyProcessor[] = "dummy";       // name of the built-in pass-through plugin
const char kLoadFunc[] = "LoadProcessor";     // presumably the symbol resolved from the plugin library -- confirm in the loader implementation

// Data type tags written into the serialized buffer header
const int64_t kDataTypeGHPairs = 1;  // buffer carries encoded g/h gradient pairs
const int64_t kDataTypeHisto = 2;    // buffer carries aggregated histograms

/*! \brief A processor interface to handle tasks that require an external library through plugins */
class Processor {
 public:
  /*!
   * \brief Initialize the processor
   *
   * \param active If true, this is the active node
   * \param params Optional parameters
   */
  virtual void Initialize(bool active, std::map<std::string, std::string> params) = 0;

  /*!
   * \brief Shutdown the processor and free all the resources
   */
  virtual void Shutdown() = 0;

  /*!
   * \brief Free a buffer
   *
   * \param buffer Any buffer returned by the calls from the plugin
   */
  virtual void FreeBuffer(void* buffer) = 0;

  /*!
   * \brief Prepare g & h pairs to be sent to other clients by the active client
   *
   * \param size  Output parameter receiving the size of the returned buffer
   * \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...) for every sample
   *
   * \return The encoded buffer to be sent
   */
  virtual void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) = 0;

  /*!
   * \brief Handle buffers with encoded pairs received from broadcast
   *
   * \param size     Output buffer size
   * \param buffer   The encoded buffer
   * \param buf_size The encoded buffer size
   *
   * \return The encoded buffer
   */
  virtual void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) = 0;

  /*!
   * \brief Initialize aggregation context by providing global GHistIndexMatrix
   *
   * \param cuts  The cut point for each feature
   * \param slots The slot assignment in a flattened matrix for each feature/row. The size is num_feature*num_row
   */
  virtual void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) = 0;

  /*!
   * \brief Prepare row set for aggregation
   *
   * \param size  The output buffer size
   * \param nodes Map of node and the rows belonging to this node
   *
   * \return The encoded buffer to be sent via AllGather
   */
  virtual void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) = 0;

  /*!
   * \brief Handle the all-gather result
   *
   * \param buffer   Buffer from all gather, only buffer from active site is needed
   * \param buf_size The size of the buffer
   *
   * \return A flattened vector of histograms for each site, each node in the form of
   *         site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3
   */
  virtual std::vector<double> HandleAggregation(void *buffer, size_t buf_size) = 0;
};

/*! \brief Loads a Processor implementation from a named plugin library. */
class ProcessorLoader {
 private:
  // Loader parameters (e.g. kLibraryPath pointing at the plugin directory).
  std::map<std::string, std::string> params;
  // Opaque handle to the loaded plugin library; presumably a dlopen/LoadLibrary
  // handle owned by load()/unload() -- confirm in the loader implementation.
  void *handle = nullptr;  // nullptr instead of the C macro NULL

 public:
  ProcessorLoader() = default;

  // const reference: the old signature took a mutable reference even though the
  // argument is only copied. `explicit` blocks accidental implicit conversion
  // from a parameter map to a loader.
  explicit ProcessorLoader(const std::map<std::string, std::string>& params): params(params) {}

  /*! \brief Load the named plugin and return its Processor instance. */
  Processor* load(const std::string& plugin_name);

  /*! \brief Unload the plugin library previously opened by load(). */
  void unload();
};

} // namespace processing

extern processing::Processor *processor_instance;
Loading

0 comments on commit 1107604

Please sign in to comment.