forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
integration with interface initial attempt
- Loading branch information
Showing
7 changed files
with
406 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#include "./dummy_processor.h" | ||
|
||
using std::vector; | ||
using std::cout; | ||
using std::endl; | ||
|
||
// Layout of the fixed-size prefix written at the start of every buffer
// produced in this file:
//   bytes  0..7   signature (kSignature, 8 characters, no terminator copied)
//   bytes  8..15  total buffer size in bytes, prefix included
//   bytes 16..23  data type tag (xgboost::processing::kDataType*)
constexpr char kSignature[] = "NVDADAM1";  // DAM (Direct Accessible Marshalling) V1
constexpr int kPrefixLen = 24;             // Size of the prefix described above, in bytes
|
||
/*!
 * \brief Encode the g & h pairs into a freshly allocated buffer.
 *
 * The buffer starts with a kPrefixLen-byte prefix (signature, total size, data
 * type) followed by every input value repeated 10 times, which simulates the
 * size increase of a real encryption.
 *
 * A pointer to \p pairs is kept for later aggregation, so the caller must keep
 * the vector alive for as long as this processor uses it.
 *
 * \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...)
 *
 * \return The encoded buffer, allocated with calloc (release it with
 *         FreeBuffer); an empty span if the allocation fails.
 */
xgboost::common::Span<int8_t> DummyProcessor::ProcessGHPairs(vector<double> &pairs) {
  cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;

  // Assume encrypted size is 10x
  constexpr std::size_t kDup = 10;
  const std::size_t buf_size = kPrefixLen + pairs.size() * kDup * sizeof(double);

  // This memory needs to be freed
  char *buf = static_cast<char *>(calloc(buf_size, 1));
  if (buf == nullptr) {
    cout << "ProcessGHPairs failed to allocate buffer of size: " << buf_size << endl;
    return xgboost::common::Span<int8_t>{};
  }

  // The size and type fields are 8 bytes wide. Widen the values first so that
  // exactly 8 valid bytes are copied; the type constant itself is only an int.
  const std::int64_t size_field = static_cast<std::int64_t>(buf_size);
  const std::int64_t type_field = xgboost::processing::kDataTypeGHPairs;
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &size_field, sizeof(size_field));
  memcpy(buf + 16, &type_field, sizeof(type_field));

  // Simulate encryption by duplicating value 10 times
  std::size_t offset = kPrefixLen;
  for (const double value : pairs) {
    for (std::size_t i = 0; i < kDup; ++i) {
      memcpy(buf + offset, &value, sizeof(value));
      offset += sizeof(value);
    }
  }

  // Save pairs for future operations
  this->gh_pairs_ = &pairs;

  return xgboost::common::Span<int8_t>(reinterpret_cast<int8_t *>(buf), buf_size);
}
|
||
/*!
 * \brief Handle the encoded g & h pairs received from broadcast.
 *
 * On a passive site this decodes the buffer into gh_pairs_. The payload holds
 * every original value repeated 10 times (see ProcessGHPairs), so only the
 * first copy of each group is kept. On the active site nothing is decoded.
 *
 * \param buffer The encoded buffer: kPrefixLen bytes of prefix, then doubles
 *
 * \return The same buffer, unchanged
 */
xgboost::common::Span<int8_t> DummyProcessor::HandleGHPairs(xgboost::common::Span<int8_t> buffer) {
  cout << "HandleGHPairs called with buffer size: " << buffer.size() << endl;

  // For dummy, this call is used to set gh_pairs for passive sites
  if (!active_) {
    // Must match the duplication factor used by ProcessGHPairs
    constexpr std::size_t kDup = 10;
    constexpr std::size_t kStride = kDup * sizeof(double);

    // Guard against buffers shorter than the prefix (unsigned underflow)
    const std::size_t total = buffer.size();
    const std::size_t prefix = static_cast<std::size_t>(kPrefixLen);
    const std::size_t payload = total > prefix ? total - prefix : 0;
    const std::size_t num = payload / kStride;

    // Reuse the vector allocated by a previous call instead of leaking one per
    // call. NOTE(review): this assumes a passive site never calls
    // ProcessGHPairs, i.e. gh_pairs_ is either null or owned by this object.
    if (gh_pairs_ == nullptr) {
      gh_pairs_ = new vector<double>();
    }
    gh_pairs_->resize(num);

    for (std::size_t i = 0; i < num; ++i) {
      // memcpy avoids assuming that the payload is aligned for double
      memcpy(&(*gh_pairs_)[i], buffer.data() + prefix + i * kStride, sizeof(double));
    }
  }

  return buffer;
}
|
||
/*!
 * \brief Build the local g/h histograms for the given nodes.
 *
 * The returned buffer holds the kPrefixLen-byte prefix followed by one
 * histogram per entry of \p nodes_to_build, in the same order. Each histogram
 * has two doubles per bin: the sum of g at slot*2 and the sum of h at slot*2+1.
 *
 * Requires InitAggregationContext and ProcessGHPairs/HandleGHPairs to have been
 * called first; gidx_ and gh_pairs_ are dereferenced without a check.
 *
 * \param nodes_to_build IDs of the nodes to build histograms for
 * \param row_set Row IDs belonging to each node
 *
 * \return The encoded buffer, allocated with calloc (release it with
 *         FreeBuffer); an empty span if the allocation fails.
 */
xgboost::common::Span<std::int8_t> DummyProcessor::ProcessAggregation(
    std::vector<xgboost::bst_node_t> const &nodes_to_build,
    xgboost::common::RowSetCollection const &row_set) {
  auto const &cuts = gidx_->Cuts();
  const std::size_t total_bin_size = cuts.Values().size();
  const std::size_t histo_size = total_bin_size * 2;
  const std::size_t buf_size = kPrefixLen + sizeof(double) * histo_size * nodes_to_build.size();

  std::int8_t *buf = static_cast<std::int8_t *>(calloc(buf_size, 1));
  if (buf == nullptr) {
    cout << "ProcessAggregation failed to allocate buffer of size: " << buf_size << endl;
    return xgboost::common::Span<std::int8_t>{};
  }

  // The size and type fields are 8 bytes wide. Widen the values first so that
  // exactly 8 valid bytes are copied; the type constant itself is only an int.
  const std::int64_t size_field = static_cast<std::int64_t>(buf_size);
  const std::int64_t type_field = xgboost::processing::kDataTypeHisto;
  memcpy(buf, kSignature, strlen(kSignature));
  memcpy(buf + 8, &size_field, sizeof(size_field));
  memcpy(buf + 16, &type_field, sizeof(type_field));

  // Ptrs() has one more entry than there are features; avoid wrapping around
  // when it is empty.
  const std::size_t num_ptrs = cuts.Ptrs().size();
  const std::size_t num_features = num_ptrs > 0 ? num_ptrs - 1 : 0;

  double *histo = reinterpret_cast<double *>(buf + kPrefixLen);
  for (auto const &node_id : nodes_to_build) {
    auto elem = row_set[node_id];
    for (auto it = elem.begin; it != elem.end; ++it) {
      const auto row_id = *it;
      // The pair of a row does not depend on the feature, read it once
      const double g = (*gh_pairs_)[row_id * 2];
      const double h = (*gh_pairs_)[row_id * 2 + 1];

      for (std::size_t f = 0; f < num_features; ++f) {
        const auto slot = gidx_->GetGindex(row_id, f);
        if (slot < 0) {
          // Negative slot: no bin for this row/feature, nothing to accumulate
          continue;
        }

        histo[slot * 2] += g;
        histo[slot * 2 + 1] += h;
      }
    }
    // Move on to the histogram of the next node
    histo += histo_size;
  }

  return xgboost::common::Span<int8_t>(reinterpret_cast<int8_t *>(buf), buf_size);
}
|
||
/*!
 * \brief Decode the gathered histogram buffers into one flat vector.
 *
 * \p buffer is a sequence of records. Each record starts with a kPrefixLen-byte
 * prefix whose bytes 8..15 hold the total record size, followed by doubles.
 * The doubles of all records are appended in order.
 *
 * Parsing stops at the first record whose size field is smaller than the
 * prefix or larger than the remaining bytes; everything decoded up to that
 * point is returned.
 *
 * \param buffer Buffer from all gather
 *
 * \return The flattened histograms
 */
std::vector<double> DummyProcessor::HandleAggregation(xgboost::common::Span<std::int8_t> buffer) {
  std::vector<double> result;

  const std::int8_t *ptr = buffer.data();
  std::size_t rest_size = buffer.size();
  const std::size_t prefix = static_cast<std::size_t>(kPrefixLen);

  while (rest_size > prefix) {
    // memcpy avoids assuming that the record is aligned for int64_t
    std::int64_t size_field = 0;
    memcpy(&size_field, ptr + 8, sizeof(size_field));

    // The size comes from the wire: a value below the prefix would never
    // advance correctly, a value above the remaining bytes would read past
    // the end of the buffer.
    if (size_field < kPrefixLen || static_cast<std::uint64_t>(size_field) > rest_size) {
      cout << "HandleAggregation found invalid record size: " << size_field
           << ", remaining bytes: " << rest_size << endl;
      break;
    }

    const std::size_t record_size = static_cast<std::size_t>(size_field);
    const std::size_t array_size = (record_size - prefix) / sizeof(double);
    if (array_size > 0) {
      const std::size_t old_size = result.size();
      result.resize(old_size + array_size);
      memcpy(result.data() + old_size, ptr + prefix, array_size * sizeof(double));
    }

    rest_size -= record_size;
    ptr += record_size;
  }

  return result;
}
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#pragma once | ||
#include <string> | ||
#include <vector> | ||
#include <map> | ||
#include "../processor.h" | ||
|
||
class DummyProcessor: public xgboost::processing::Processor { | ||
private: | ||
bool active_ = false; | ||
const std::map<std::string, std::string> *params_; | ||
std::vector<double> *gh_pairs_{nullptr}; | ||
const xgboost::GHistIndexMatrix *gidx_; | ||
|
||
public: | ||
void Initialize(bool active, std::map<std::string, std::string> params) override { | ||
this->active_ = active; | ||
this->params_ = ¶ms; | ||
} | ||
|
||
void Shutdown() override { | ||
this->gh_pairs_ = nullptr; | ||
this->gidx_ = nullptr; | ||
} | ||
|
||
void FreeBuffer(xgboost::common::Span<std::int8_t> buffer) override { | ||
free(buffer.data()); | ||
} | ||
|
||
xgboost::common::Span<int8_t> ProcessGHPairs(std::vector<double> &pairs) override; | ||
|
||
xgboost::common::Span<int8_t> HandleGHPairs(xgboost::common::Span<int8_t> buffer) override; | ||
|
||
void InitAggregationContext(xgboost::GHistIndexMatrix const &gidx) override { | ||
this->gidx_ = &gidx; | ||
} | ||
|
||
xgboost::common::Span<std::int8_t> ProcessAggregation(std::vector<xgboost::bst_node_t> const &nodes_to_build, | ||
xgboost::common::RowSetCollection const &row_set) override; | ||
|
||
std::vector<double> HandleAggregation(xgboost::common::Span<std::int8_t> buffer) override; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/** | ||
* Copyright 2014-2024 by XGBoost Contributors | ||
*/ | ||
#pragma once | ||
|
||
#include <xgboost/span.h> | ||
#include <map> | ||
#include <any> | ||
#include <string> | ||
#include <vector> | ||
#include "../data/gradient_index.h" | ||
|
||
namespace xgboost::processing { | ||
|
||
// Names used when loading a processor plugin.
// NOTE(review): presumably kLibraryPath is a parameter key, kDummyProcessor a
// plugin name and kLoadFunc the exported factory symbol - confirm in
// ProcessorLoader::load.
constexpr char kLibraryPath[] = "LIBRARY_PATH";
constexpr char kDummyProcessor[] = "dummy";
constexpr char kLoadFunc[] = "LoadProcessor";

// Data type definition
constexpr int kDataTypeGHPairs = 1;
constexpr int kDataTypeHisto = 2;
|
||
/*! \brief An processor interface to handle tasks that require external library through plugins */ | ||
class Processor { | ||
public: | ||
/*! | ||
* \brief Initialize the processor | ||
* | ||
* \param active If true, this is the active node | ||
* \param params Optional parameters | ||
*/ | ||
virtual void Initialize(bool active, std::map<std::string, std::string> params) = 0; | ||
|
||
/*! | ||
* \brief Shutdown the processor and free all the resources | ||
* | ||
*/ | ||
virtual void Shutdown() = 0; | ||
|
||
/*! | ||
* \brief Free buffer | ||
* | ||
* \param buffer Any buffer returned by the calls from the plugin | ||
*/ | ||
virtual void FreeBuffer(common::Span<std::int8_t> buffer) = 0; | ||
|
||
/*! | ||
* \brief Preparing g & h pairs to be sent to other clients by active client | ||
* | ||
* \param pairs g&h pairs in a vector (g1, h1, g2, h2 ...) for every sample | ||
* | ||
* \return The encoded buffer to be sent | ||
*/ | ||
virtual common::Span<std::int8_t> ProcessGHPairs(std::vector<double>& pairs) = 0; | ||
|
||
/*! | ||
* \brief Handle buffers with encoded pairs received from broadcast | ||
* | ||
* \param The encoded buffer | ||
* | ||
* \return The encoded buffer | ||
*/ | ||
virtual common::Span<std::int8_t> HandleGHPairs(common::Span<std::int8_t> buffer) = 0; | ||
|
||
/*! | ||
* \brief Initialize aggregation context by providing global GHistIndexMatrix | ||
* | ||
* \param gidx The matrix for every sample with its feature and slot assignment | ||
*/ | ||
virtual void InitAggregationContext(GHistIndexMatrix const &gidx) = 0; | ||
|
||
/*! | ||
* \brief Prepare row set for aggregation | ||
* | ||
* \param row_set Information for node IDs and its sample IDs | ||
* | ||
* \return The encoded buffer to be sent via AllGather | ||
*/ | ||
virtual common::Span<std::int8_t> ProcessAggregation(std::vector<bst_node_t> const &nodes_to_build, | ||
common::RowSetCollection const &row_set) = 0; | ||
|
||
/*! | ||
* \brief Handle all gather result | ||
* | ||
* \param buffers Buffer from all gather, only buffer from active site is needed | ||
* | ||
* \return A flattened vector of histograms for each site, each node in the form of | ||
* site1_node1, site1_node2 site1_node3, site2_node1, site2_node2, site2_node3 | ||
*/ | ||
virtual std::vector<double> HandleAggregation(common::Span<std::int8_t> buffer) = 0; | ||
}; | ||
|
||
class ProcessorLoader { | ||
private: | ||
std::map<std::string, std::string> params; | ||
void *handle = NULL; | ||
|
||
|
||
public: | ||
ProcessorLoader(): params{} {} | ||
|
||
ProcessorLoader(std::map<std::string, std::string>& params): params(params) {} | ||
|
||
Processor* load(const std::string& plugin_name); | ||
|
||
void unload(); | ||
}; | ||
} // namespace xgboost::processing |
Oops, something went wrong.