-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Updated FOBS readme to add DatumManager, added agrpcs as secure scheme * Refactoring * Refactored the secure version to histogram_based_v2 * Replaced Paillier with a mock encryptor * Added license header * Put mock back * Added metrics_writer back and fixed GRPC error reply
- Loading branch information
Showing
71 changed files
with
3,182 additions
and
559 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Build script for the proc_nvflare shared library (XGBoost processor plugin
# for NVFlare) and its optional GoogleTest-based unit tests.
cmake_minimum_required(VERSION 3.19)
project(proc_nvflare LANGUAGES CXX C VERSION 1.0)
set(CMAKE_CXX_STANDARD 17)

# Default to a Debug build, but do not override a build type the user selected
# (e.g. -DCMAKE_BUILD_TYPE=Release) or the configurations of a multi-config generator.
if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
  set(CMAKE_BUILD_TYPE Debug)
endif ()

option(GOOGLE_TEST "Build google tests" OFF)

file(GLOB_RECURSE LIB_SRC
  "src/*.h"
  "src/*.cc"
)

add_library(proc_nvflare SHARED ${LIB_SRC})

# The XGBoost source tree is expected four directories above this one.
set(XGB_SRC ${proc_nvflare_SOURCE_DIR}/../../../../xgboost)
target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
  ${XGB_SRC}/src
  ${XGB_SRC}/rabit/include
  ${XGB_SRC}/include
  ${XGB_SRC}/dmlc-core/include)

# NOTE(review): link_directories() only applies to targets created after this
# call, so it does not affect proc_nvflare (which links libxgboost by full path below).
link_directories(${XGB_SRC}/lib/)

if (APPLE)
  # NOTE(review): add_link_options() only affects targets added after this call,
  # so proc_nvflare (created above) does not receive these flags - confirm intent.
  add_link_options("LINKER:-object_path_lto,$<TARGET_PROPERTY:NAME>_lto.o")
  add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
endif ()

target_link_libraries(proc_nvflare ${XGB_SRC}/lib/libxgboost${CMAKE_SHARED_LIBRARY_SUFFIX})

#-- Unit Tests
if(GOOGLE_TEST)
  find_package(GTest REQUIRED)
  enable_testing()
  # The executable is declared without sources here; the tests subdirectory
  # added below is expected to contribute them.
  add_executable(proc_test)
  target_link_libraries(proc_test PRIVATE proc_nvflare)

  target_include_directories(proc_test PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
    ${XGB_SRC}/src
    ${XGB_SRC}/rabit/include
    ${XGB_SRC}/include
    ${XGB_SRC}/dmlc-core/include
    ${XGB_SRC}/tests)

  add_subdirectory(${proc_nvflare_SOURCE_DIR}/tests)

  add_test(
    NAME TestProcessor
    COMMAND proc_test
    WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR})

endif()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Build Instructions | ||
|
||
Building this plugin requires the XGBoost source code. Check out the XGBoost source and build it with the federated plugin enabled, then build this plugin: | ||
|
||
cd xgboost | ||
mkdir build | ||
cd build | ||
cmake .. -DPLUGIN_FEDERATED=ON | ||
make | ||
|
||
cd NVFlare/integration/xgboost/processor | ||
mkdir build | ||
cd build | ||
cmake .. | ||
make |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# encoding-plugins | ||
Processor Plugin for NVFlare | ||
|
||
This plugin is a companion to NVFlare-based encryption. It processes the data so that it can | ||
be properly decoded by the Python code running on NVFlare. | ||
|
||
All encryption happens in the local gRPC client/server, so no encryption is needed | ||
in this plugin. | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# DAM (Direct-Accessible Marshaller) | ||
|
||
A simple serialization library that doesn't have dependencies, and the data | ||
is directly accessible in C/C++ without copying. | ||
|
||
To make the data directly accessible in C, the following rules must be followed: | ||
|
||
1. Numeric values must be stored in native byte-order. | ||
2. Numeric values must start at 64-bit (8-byte) boundaries. | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
/** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <iostream> | ||
#include <cstring> | ||
#include "dam.h" | ||
|
||
/**
 * Debug helper: write a byte buffer to stdout as space-separated hex values.
 *
 * Every byte is printed without zero padding and followed by one space; the
 * line is then terminated with a newline and flushed.
 *
 * @param buffer First byte to print. When null, only the newline is written.
 * @param size   Number of bytes to print. Values <= 0 print only the newline.
 */
void print_buffer(uint8_t *buffer, int size) {
    // Remember the caller's formatting state so that switching to hex here
    // does not silently change how later output on std::cout is rendered.
    const std::ios_base::fmtflags saved_flags = std::cout.flags();

    std::cout << std::hex;
    if (buffer != nullptr) {
        for (int i = 0; i < size; ++i) {
            // Widen to int: streaming a uint8_t would print it as a character.
            std::cout << static_cast<int>(buffer[i]) << " ";
        }
    }
    std::cout << std::endl;

    std::cout.flags(saved_flags);
}
|
||
// DamEncoder ====== | ||
void DamEncoder::AddFloatArray(std::vector<double> &value) { | ||
if (encoded) { | ||
std::cout << "Buffer is already encoded" << std::endl; | ||
return; | ||
} | ||
auto buf_size = value.size()*8; | ||
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size)); | ||
memcpy(buffer, value.data(), buf_size); | ||
// print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8); | ||
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); | ||
} | ||
|
||
void DamEncoder::AddIntArray(std::vector<int64_t> &value) { | ||
std::cout << "AddIntArray called, size: " << value.size() << std::endl; | ||
if (encoded) { | ||
std::cout << "Buffer is already encoded" << std::endl; | ||
return; | ||
} | ||
auto buf_size = value.size()*8; | ||
std::cout << "Allocating " << buf_size << " bytes" << std::endl; | ||
uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size)); | ||
memcpy(buffer, value.data(), buf_size); | ||
// print_buffer(buffer, buf_size); | ||
entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size())); | ||
} | ||
|
||
std::uint8_t * DamEncoder::Finish(size_t &size) { | ||
encoded = true; | ||
|
||
size = calculate_size(); | ||
auto buf = static_cast<uint8_t *>(malloc(size)); | ||
auto pointer = buf; | ||
memcpy(pointer, kSignature, strlen(kSignature)); | ||
memcpy(pointer+8, &size, 8); | ||
memcpy(pointer+16, &data_set_id, 8); | ||
|
||
pointer += kPrefixLen; | ||
for (auto entry : *entries) { | ||
memcpy(pointer, &entry->data_type, 8); | ||
pointer += 8; | ||
memcpy(pointer, &entry->size, 8); | ||
pointer += 8; | ||
int len = 8*entry->size; | ||
memcpy(pointer, entry->pointer, len); | ||
free(entry->pointer); | ||
pointer += len; | ||
// print_buffer(entry->pointer, entry->size*8); | ||
} | ||
|
||
if ((pointer - buf) != size) { | ||
std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl; | ||
return nullptr; | ||
} | ||
|
||
return buf; | ||
} | ||
|
||
std::size_t DamEncoder::calculate_size() { | ||
auto size = kPrefixLen; | ||
|
||
for (auto entry : *entries) { | ||
size += 16; // The Type and Len | ||
size += entry->size * 8; // All supported data types are 8 bytes | ||
} | ||
|
||
return size; | ||
} | ||
|
||
|
||
// DamDecoder ====== | ||
|
||
/**
 * Wrap an encoded buffer for decoding; the buffer is not copied.
 *
 * The read position starts right after the prefix. When the buffer is large
 * enough to hold a prefix, the encoded length (offset 8) and the data set id
 * (offset 16) are read from it; otherwise both are set to zero.
 *
 * @param buffer Start of the encoded data.
 * @param size   Number of bytes available at buffer.
 */
DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size)
    : buffer(buffer), buf_size(size), pos(buffer + kPrefixLen) {
    const bool has_prefix = size >= kPrefixLen;
    if (!has_prefix) {
        len = 0;
        data_set_id = 0;
        return;
    }

    memcpy(&len, buffer + 8, 8);
    memcpy(&data_set_id, buffer + 16, 8);
}
|
||
bool DamDecoder::IsValid() { | ||
return buf_size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0; | ||
} | ||
|
||
std::vector<int64_t> DamDecoder::DecodeIntArray() { | ||
auto type = *reinterpret_cast<int64_t *>(pos); | ||
if (type != kDataTypeIntArray) { | ||
std::cout << "Data type " << type << " doesn't match Int Array" << std::endl; | ||
return std::vector<int64_t>(); | ||
} | ||
pos += 8; | ||
|
||
auto len = *reinterpret_cast<int64_t *>(pos); | ||
pos += 8; | ||
auto ptr = reinterpret_cast<int64_t *>(pos); | ||
pos += 8*len; | ||
return std::vector<int64_t>(ptr, ptr + len); | ||
} | ||
|
||
std::vector<double> DamDecoder::DecodeFloatArray() { | ||
auto type = *reinterpret_cast<int64_t *>(pos); | ||
if (type != kDataTypeFloatArray) { | ||
std::cout << "Data type " << type << " doesn't match Float Array" << std::endl; | ||
return std::vector<double>(); | ||
} | ||
pos += 8; | ||
|
||
auto len = *reinterpret_cast<int64_t *>(pos); | ||
pos += 8; | ||
|
||
auto ptr = reinterpret_cast<double *>(pos); | ||
pos += 8*len; | ||
return std::vector<double>(ptr, ptr + len); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
/** | ||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <string>
#include <vector>
|
||
// Signature placed at the very start of an encoded buffer:
// DAM (Direct Accessible Marshalling) V1
constexpr char kSignature[] = "NVDADAM1";
// Number of bytes in the buffer prefix that precedes the first entry
constexpr int kPrefixLen = 24;

// Data type tags stored ahead of each entry's payload
constexpr int kDataTypeInt = 1;
constexpr int kDataTypeFloat = 2;
constexpr int kDataTypeString = 3;
constexpr int kDataTypeIntArray = 257;
constexpr int kDataTypeFloatArray = 258;

constexpr int kDataTypeMap = 1025;
|
||
/**
 * One queued item of an encoder: a data type tag, a pointer to the raw
 * payload bytes, and the number of elements in that payload.
 */
class Entry {
 public:
    int64_t data_type;  // one of the kDataType* tags
    uint8_t *pointer;   // start of the payload bytes
    int64_t size;       // element count (not a byte count)

    Entry(int64_t data_type, uint8_t *pointer, int64_t size)
        : data_type(data_type), pointer(pointer), size(size) {}
};
|
||
class DamEncoder { | ||
private: | ||
bool encoded = false; | ||
int64_t data_set_id; | ||
std::vector<Entry *> *entries = new std::vector<Entry *>(); | ||
|
||
public: | ||
explicit DamEncoder(int64_t data_set_id) { | ||
this->data_set_id = data_set_id; | ||
} | ||
|
||
void AddIntArray(std::vector<int64_t> &value); | ||
|
||
void AddFloatArray(std::vector<double> &value); | ||
|
||
std::uint8_t * Finish(size_t &size); | ||
|
||
private: | ||
std::size_t calculate_size(); | ||
}; | ||
|
||
/**
 * Reads entries back out of a DAM-encoded buffer.
 *
 * The decoder keeps a pointer to the caller's buffer and a read position
 * that the Decode* methods advance as entries are consumed.
 */
class DamDecoder {
 private:
    std::uint8_t *buffer = nullptr;  // start of the encoded data (not copied)
    std::size_t buf_size = 0;        // bytes available at buffer
    std::uint8_t *pos = nullptr;     // current read position
    std::size_t remaining = 0;
    int64_t data_set_id = 0;         // data set id read from the prefix
    int64_t len = 0;                 // encoded length read from the prefix

 public:
    explicit DamDecoder(std::uint8_t *buffer, std::size_t size);

    /** Encoded length as recorded in the buffer prefix (0 if no prefix was read). */
    size_t Size() const {
        return static_cast<size_t>(len);
    }

    /** Data set id as recorded in the buffer prefix (0 if no prefix was read). */
    int64_t GetDataSetId() const {
        return data_set_id;
    }

    /** Whether the buffer is long enough for a prefix and starts with the DAM signature. */
    bool IsValid();

    /** Decode the next entry as an integer array; empty on type mismatch. */
    std::vector<int64_t> DecodeIntArray();

    /** Decode the next entry as a float (double) array; empty on type mismatch. */
    std::vector<double> DecodeFloatArray();
};
|
||
void print_buffer(uint8_t *buffer, int size); |
Oops, something went wrong.