forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move processor interface init from learner to communicator
- Loading branch information
Showing
6 changed files
with
379 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#include <iostream> | ||
#include "./dummy_processor.h" | ||
|
||
using std::vector; | ||
using std::cout; | ||
using std::endl; | ||
|
||
const char kSignature[] = "NVDADAM1"; // DAM (Direct Accessible Marshalling) V1 | ||
const int64_t kPrefixLen = 24; | ||
|
||
bool ValidDam(void *buffer) { | ||
return memcmp(buffer, kSignature, strlen(kSignature)) == 0; | ||
} | ||
|
||
void* DummyProcessor::ProcessGHPairs(size_t &size, std::vector<double>& pairs) { | ||
cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl; | ||
|
||
size = kPrefixLen + pairs.size()*10*8; // Assume encrypted size is 10x | ||
|
||
int64_t buf_size = size; | ||
// This memory needs to be freed | ||
char *buf = static_cast<char *>(calloc(size, 1)); | ||
memcpy(buf, kSignature, strlen(kSignature)); | ||
memcpy(buf + 8, &buf_size, 8); | ||
memcpy(buf + 16, &processing::kDataTypeGHPairs, 8); | ||
|
||
// Simulate encryption by duplicating value 10 times | ||
int index = kPrefixLen; | ||
for (auto value : pairs) { | ||
for (int i = 0; i < 10; i++) { | ||
memcpy(buf+index, &value, 8); | ||
index += 8; | ||
} | ||
} | ||
|
||
// Save pairs for future operations | ||
this->gh_pairs_ = new vector<double>(pairs); | ||
|
||
return buf; | ||
} | ||
|
||
|
||
void* DummyProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) { | ||
cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl; | ||
|
||
if (!ValidDam(buffer)) { | ||
cout << "Invalid buffer received" << endl; | ||
return buffer; | ||
} | ||
|
||
// For dummy, this call is used to set gh_pairs for passive sites | ||
if (!active_) { | ||
int8_t *ptr = static_cast<int8_t *>(buffer); | ||
ptr += kPrefixLen; | ||
double *pairs = reinterpret_cast<double *>(ptr); | ||
size_t num = (buf_size - kPrefixLen) / 8; | ||
gh_pairs_ = new vector<double>(); | ||
for (int i = 0; i < num; i += 10) { | ||
gh_pairs_->push_back(pairs[i]); | ||
} | ||
cout << "GH Pairs saved. Size: " << gh_pairs_->size() << endl; | ||
} | ||
|
||
return buffer; | ||
} | ||
|
||
void *DummyProcessor::ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) { | ||
auto total_bin_size = cuts_.back(); | ||
auto histo_size = total_bin_size*2; | ||
size = kPrefixLen + 8*histo_size*nodes.size(); | ||
int64_t buf_size = size; | ||
cout << "ProcessAggregation called with bin size: " << total_bin_size << " Buffer Size: " << buf_size << endl; | ||
std::int8_t *buf = static_cast<std::int8_t *>(calloc(buf_size, 1)); | ||
memcpy(buf, kSignature, strlen(kSignature)); | ||
memcpy(buf + 8, &buf_size, 8); | ||
memcpy(buf + 16, &processing::kDataTypeHisto, 8); | ||
|
||
double *histo = reinterpret_cast<double *>(buf + kPrefixLen); | ||
for ( const auto &node : nodes ) { | ||
auto rows = node.second; | ||
for (const auto &row_id : rows) { | ||
|
||
auto num = cuts_.size() - 1; | ||
for (std::size_t f = 0; f < num; f++) { | ||
auto slot = slots_[f + num*row_id]; | ||
if (slot < 0) { | ||
continue; | ||
} | ||
|
||
if (slot >= total_bin_size) { | ||
cout << "Slot too big, ignored: " << slot << endl; | ||
continue; | ||
} | ||
|
||
if (row_id >= gh_pairs_->size()/2) { | ||
cout << "Row ID too big: " << row_id << endl; | ||
} | ||
|
||
auto g = (*gh_pairs_)[row_id*2]; | ||
auto h = (*gh_pairs_)[row_id*2+1]; | ||
histo[slot*2] += g; | ||
histo[slot*2+1] += h; | ||
} | ||
} | ||
histo += histo_size; | ||
} | ||
|
||
return buf; | ||
} | ||
|
||
std::vector<double> DummyProcessor::HandleAggregation(void *buffer, size_t buf_size) { | ||
cout << "HandleAggregation called with buffer size: " << buf_size << endl; | ||
std::vector<double> result = std::vector<double>(); | ||
|
||
int8_t* ptr = static_cast<int8_t *>(buffer); | ||
auto rest_size = buf_size; | ||
|
||
while (rest_size > kPrefixLen) { | ||
if (!ValidDam(ptr)) { | ||
cout << "Invalid buffer at offset " << buf_size - rest_size << endl; | ||
continue; | ||
} | ||
std::int64_t *size_ptr = reinterpret_cast<std::int64_t *>(ptr + 8); | ||
double *array_start = reinterpret_cast<double *>(ptr + kPrefixLen); | ||
auto array_size = (*size_ptr - kPrefixLen)/8; | ||
cout << "Histo size for buffer: " << array_size << endl; | ||
result.insert(result.end(), array_start, array_start + array_size); | ||
cout << "Result size: " << result.size() << endl; | ||
rest_size -= *size_ptr; | ||
ptr = ptr + *size_ptr; | ||
} | ||
|
||
cout << "Total histo size: " << result.size() << endl; | ||
|
||
return result; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#pragma once | ||
#include <string> | ||
#include <cstring> | ||
#include <vector> | ||
#include <map> | ||
#include "../processor.h" | ||
|
||
class DummyProcessor: public processing::Processor { | ||
private: | ||
bool active_ = false; | ||
const std::map<std::string, std::string> *params_{nullptr}; | ||
std::vector<double> *gh_pairs_{nullptr}; | ||
std::vector<uint32_t> cuts_; | ||
std::vector<int> slots_; | ||
|
||
public: | ||
void Initialize(bool active, std::map<std::string, std::string> params) override { | ||
this->active_ = active; | ||
this->params_ = ¶ms; | ||
} | ||
|
||
void Shutdown() override { | ||
this->gh_pairs_ = nullptr; | ||
this->cuts_.clear(); | ||
this->slots_.clear(); | ||
} | ||
|
||
void FreeBuffer(void *buffer) override { | ||
free(buffer); | ||
} | ||
|
||
void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) override; | ||
|
||
void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override; | ||
|
||
void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) override { | ||
std::cout << "InitAggregationContext called with cuts size: " << cuts.size()-1 << | ||
" number of slot: " << slots.size() << std::endl; | ||
this->cuts_ = cuts; | ||
if (this->slots_.empty()) { | ||
this->slots_ = slots; | ||
} else { | ||
std::cout << "Multiple calls to InitAggregationContext" << std::endl; | ||
} | ||
} | ||
|
||
void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) override; | ||
|
||
std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#pragma once | ||
|
||
#include <map> | ||
#include <any> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace processing { | ||
|
||
const char kLibraryPath[] = "LIBRARY_PATH"; | ||
const char kDummyProcessor[] = "dummy"; | ||
const char kLoadFunc[] = "LoadProcessor"; | ||
|
||
// Data type definition | ||
const int64_t kDataTypeGHPairs = 1; | ||
const int64_t kDataTypeHisto = 2; | ||
|
||
/*! \brief An processor interface to handle tasks that require external library through plugins */ | ||
class Processor { | ||
public: | ||
/*! | ||
* \brief Initialize the processor | ||
* | ||
* \param active If true, this is the active node | ||
* \param params Optional parameters | ||
*/ | ||
virtual void Initialize(bool active, std::map<std::string, std::string> params) = 0; | ||
|
||
/*! | ||
* \brief Shutdown the processor and free all the resources | ||
* | ||
*/ | ||
virtual void Shutdown() = 0; | ||
|
||
/*! | ||
* \brief Free buffer | ||
* | ||
* \param buffer Any buffer returned by the calls from the plugin | ||
*/ | ||
virtual void FreeBuffer(void* buffer) = 0; | ||
|
||
/*! | ||
* \brief Preparing g & h pairs to be sent to other clients by active client | ||
* | ||
* \param size The size of the buffer | ||
* \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...) for every sample | ||
* | ||
* \return The encoded buffer to be sent | ||
*/ | ||
virtual void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) = 0; | ||
|
||
/*! | ||
* \brief Handle buffers with encoded pairs received from broadcast | ||
* | ||
* \param size Output buffer size | ||
* \param The encoded buffer | ||
* \param The encoded buffer size | ||
* | ||
* \return The encoded buffer | ||
*/ | ||
virtual void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) = 0; | ||
|
||
/*! | ||
* \brief Initialize aggregation context by providing global GHistIndexMatrix | ||
* | ||
* \param cuts The cut point for each feature | ||
* \param slots The slot assignment in a flattened matrix for each feature/row. The size is num_feature*num_row | ||
*/ | ||
virtual void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) = 0; | ||
|
||
/*! | ||
* \brief Prepare row set for aggregation | ||
* | ||
* \param size The output buffer size | ||
* \param nodes Map of node and the rows belong to this node | ||
* | ||
* \return The encoded buffer to be sent via AllGather | ||
*/ | ||
virtual void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) = 0; | ||
|
||
/*! | ||
* \brief Handle all gather result | ||
* | ||
* \param buffer Buffer from all gather, only buffer from active site is needed | ||
* \param buf_size The size of the buffer | ||
* | ||
* \return A flattened vector of histograms for each site, each node in the form of | ||
* site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3 | ||
*/ | ||
virtual std::vector<double> HandleAggregation(void *buffer, size_t buf_size) = 0; | ||
}; | ||
|
||
class ProcessorLoader { | ||
private: | ||
std::map<std::string, std::string> params; | ||
void *handle = NULL; | ||
|
||
|
||
public: | ||
ProcessorLoader(): params{} {} | ||
|
||
ProcessorLoader(std::map<std::string, std::string>& params): params(params) {} | ||
|
||
Processor* load(const std::string& plugin_name); | ||
|
||
void unload(); | ||
}; | ||
|
||
} // namespace processing | ||
|
||
extern processing::Processor *processor_instance; |
Oops, something went wrong.