Horizontal Secure XGBoost Support #2562

Merged: 33 commits merged on May 16, 2024
Changes from all commits (33 commits):
a655d9d
Updated FOBS readme to add DatumManager, added agrpcs as secure scheme
nvidianz Mar 18, 2024
4e5ba5d
Merge branch 'NVIDIA:main' into main
nvidianz Mar 25, 2024
9f90d48
Merge branch 'NVIDIA:main' into main
nvidianz Apr 10, 2024
84fc5bd
Merge branch 'NVIDIA:main' into main
nvidianz Apr 15, 2024
f972506
Merge branch 'NVIDIA:main' into main
nvidianz Apr 23, 2024
5f6e76f
Merge branch 'NVIDIA:main' into main
nvidianz Apr 26, 2024
15884c4
Merge branch 'NVIDIA:main' into main
nvidianz May 3, 2024
6fbcbaa
Merge branch 'NVIDIA:main' into main
nvidianz May 6, 2024
2646f60
Implemented horizontal calls in nvflare plugin
nvidianz May 3, 2024
306b8e1
Added support for horizontal secure XGBoost
nvidianz May 5, 2024
6d75082
Fixed a few horizontal issues
nvidianz May 6, 2024
213ca1d
Added reliable message
nvidianz May 8, 2024
07d3787
Merge branch 'NVIDIA:main' into main
nvidianz May 8, 2024
d946eaa
Added ReliableMessage parameters
nvidianz May 8, 2024
bb4a934
Added log for debugging empty rcv_buf
nvidianz May 10, 2024
2037db4
Added finally block to finish duplicate seq
nvidianz May 10, 2024
54af72a
Removed debug statements
nvidianz May 10, 2024
c2ffb3b
format change
nvidianz May 10, 2024
98289c4
Merge branch 'NVIDIA:main' into main
nvidianz May 13, 2024
cb8cdef
Merge branch 'NVIDIA:main' into main
nvidianz May 14, 2024
b43241b
Add in process client api tests (#2549)
YuanTingHsieh May 7, 2024
918f248
Add client controller executor (#2530)
SYangster May 7, 2024
c7d1bee
Add option in dashboard cli for AWS vpc and subnet
IsaacYangSLA May 8, 2024
5c6923f
[2.5] Clean up to allow creation of nvflare light (#2573)
yanchengnv May 9, 2024
23e4da2
Enable patch and build for nvflight (#2574)
IsaacYangSLA May 9, 2024
1ba16d8
add FedBN Implementation on NVFlare research folder - a local batch n…
MinghuiChen43 May 10, 2024
8fd2291
fix MLFLOW example (#2575)
chesterxgchen May 11, 2024
fce5ebb
BugFix: InProcessClientAPIExecutor's TaskScriptRunner (#2558)
chesterxgchen May 11, 2024
b11d6e3
update client_api.png (#2577)
chesterxgchen May 14, 2024
44d9136
Fix the simulator worker sys path (#2561)
yhwen May 14, 2024
ffa9b0b
Merge branch 'main' into flare-1935-xgb-sec-horizontal
nvidianz May 14, 2024
a85f0bd
ReliableMessage register is changed to register aux message. Added su…
nvidianz May 15, 2024
75572ae
Merge branch 'main' into flare-1935-xgb-sec-horizontal
nvidianz May 16, 2024
4 changes: 2 additions & 2 deletions integration/xgboost/processor/src/dam/dam.cc
@@ -26,7 +26,7 @@ void print_buffer(uint8_t *buffer, int size) {
}

// DamEncoder ======
void DamEncoder::AddFloatArray(std::vector<double> &value) {
void DamEncoder::AddFloatArray(const std::vector<double> &value) {
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
return;
@@ -38,7 +38,7 @@ void DamEncoder::AddFloatArray(std::vector<double> &value) {
entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
}

void DamEncoder::AddIntArray(std::vector<int64_t> &value) {
void DamEncoder::AddIntArray(const std::vector<int64_t> &value) {
std::cout << "AddIntArray called, size: " << value.size() << std::endl;
if (encoded) {
std::cout << "Buffer is already encoded" << std::endl;
4 changes: 2 additions & 2 deletions integration/xgboost/processor/src/include/dam.h
@@ -53,9 +53,9 @@ class DamEncoder {
this->data_set_id = data_set_id;
}

void AddIntArray(std::vector<int64_t> &value);
void AddIntArray(const std::vector<int64_t> &value);

void AddFloatArray(std::vector<double> &value);
void AddFloatArray(const std::vector<double> &value);

std::uint8_t * Finish(size_t &size);

13 changes: 9 additions & 4 deletions integration/xgboost/processor/src/include/nvflare_processor.h
@@ -24,6 +24,8 @@ const int kDataSetHGPairs = 1;
const int kDataSetAggregation = 2;
const int kDataSetAggregationWithFeatures = 3;
const int kDataSetAggregationResult = 4;
const int kDataSetHistograms = 5;
const int kDataSetHistogramResult = 6;

class NVFlareProcessor: public processing::Processor {
private:
@@ -51,11 +53,11 @@ class NVFlareProcessor: public processing::Processor {
free(buffer);
}

void* ProcessGHPairs(size_t &size, std::vector<double>& pairs) override;
void* ProcessGHPairs(size_t *size, const std::vector<double>& pairs) override;

void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override;
void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override;

void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) override {
void InitAggregationContext(const std::vector<uint32_t> &cuts, const std::vector<int> &slots) override {
if (this->slots_.empty()) {
this->cuts_ = std::vector<uint32_t>(cuts);
this->slots_ = std::vector<int>(slots);
@@ -64,8 +66,11 @@ }
}
}

void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) override;
void *ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) override;

std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;

void *ProcessHistograms(size_t *size, const std::vector<double>& histograms) override;

std::vector<double> HandleHistograms(void *buffer, size_t buf_size) override;
};
nvflare_processor.cc
@@ -23,24 +23,24 @@ using std::vector;
using std::cout;
using std::endl;

void* NVFlareProcessor::ProcessGHPairs(size_t &size, std::vector<double>& pairs) {
void* NVFlareProcessor::ProcessGHPairs(size_t *size, const std::vector<double>& pairs) {
cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;
gh_pairs_ = new std::vector<double>(pairs);

DamEncoder encoder(kDataSetHGPairs);
encoder.AddFloatArray(pairs);
auto buffer = encoder.Finish(size);
auto buffer = encoder.Finish(*size);

return buffer;
}

void* NVFlareProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) {
void* NVFlareProcessor::HandleGHPairs(size_t *size, void *buffer, size_t buf_size) {
cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl;
size = buf_size;
*size = buf_size;
return buffer;
}

void *NVFlareProcessor::ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) {
void *NVFlareProcessor::ProcessAggregation(size_t *size, std::map<int, std::vector<int>> nodes) {
cout << "ProcessAggregation called with " << nodes.size() << " nodes" << endl;

int64_t data_set_id;
@@ -107,7 +107,7 @@ void *NVFlareProcessor::ProcessAggregation(size_t &size, std::map<int, std::vect
encoder.AddIntArray(rows);
}

auto buffer = encoder.Finish(size);
auto buffer = encoder.Finish(*size);
return buffer;
}

@@ -124,7 +124,8 @@ std::vector<double> NVFlareProcessor::HandleAggregation(void *buffer, size_t buf
while (remaining > kPrefixLen) {
DamDecoder decoder(reinterpret_cast<uint8_t *>(pointer), remaining);
if (!decoder.IsValid()) {
cout << "Not DAM encoded buffer ignored at offset: " << (int)(pointer - (char *)buffer) << endl;
cout << "Not DAM encoded buffer ignored at offset: "
<< static_cast<int>((pointer - reinterpret_cast<char *>(buffer))) << endl;
break;
}
auto size = decoder.Size();
@@ -153,6 +154,31 @@ std::vector<double> NVFlareProcessor::HandleAggregation(void *buffer, size_t buf
return result;
}

void *NVFlareProcessor::ProcessHistograms(size_t *size, const std::vector<double>& histograms) {
cout << "ProcessHistograms called with " << histograms.size() << " entries" << endl;

DamEncoder encoder(kDataSetHistograms);
encoder.AddFloatArray(histograms);
return encoder.Finish(*size);
}

std::vector<double> NVFlareProcessor::HandleHistograms(void *buffer, size_t buf_size) {
cout << "HandleHistograms called with buffer size: " << buf_size << endl;

DamDecoder decoder(reinterpret_cast<uint8_t *>(buffer), buf_size);
if (!decoder.IsValid()) {
cout << "Not DAM encoded buffer, ignored" << endl;
return std::vector<double>();
}

if (decoder.GetDataSetId() != kDataSetHistogramResult) {
cout << "Invalid dataset: " << decoder.GetDataSetId() << endl;
return std::vector<double>();
}

return decoder.DecodeFloatArray();
}

extern "C" {

processing::Processor *LoadProcessor(char *plugin_name) {
@@ -163,4 +189,5 @@ processing::Processor *LoadProcessor(char *plugin_name) {

return new NVFlareProcessor();
}
}

} // extern "C"
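
The new kDataSetHistograms / kDataSetHistogramResult pair carries the horizontal flow: ProcessHistograms packs the locally built histograms into a DAM buffer for NVFlare to aggregate across clients, and HandleHistograms unpacks the aggregated result. Conceptually that aggregation is an element-wise sum of equally sized per-client histograms; the sketch below is a plain-Python illustration of that idea only, not the plugin's C++ and not NVFlare's actual aggregation code, and the function name is made up.

from typing import List, Sequence

def sum_histograms(client_histograms: Sequence[Sequence[float]]) -> List[float]:
    # Element-wise sum of flattened histograms, one per client.
    # Illustrative only: in the real flow this happens on the server side,
    # optionally under encryption for the secure variant.
    if not client_histograms:
        return []
    length = len(client_histograms[0])
    if any(len(h) != length for h in client_histograms):
        raise ValueError("all client histograms must have the same length")
    return [sum(bins) for bins in zip(*client_histograms)]

# Two clients, each contributing a 4-bin flattened histogram:
print(sum_histograms([[1.0, 2.0, 0.0, 3.0], [0.5, 1.5, 2.0, 0.0]]))  # [1.5, 3.5, 2.0, 3.0]
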
10 changes: 9 additions & 1 deletion nvflare/apis/utils/reliable_message.py
@@ -216,20 +216,28 @@ class ReliableMessage:
_logger = logging.getLogger("ReliableMessage")

@classmethod
def register_request_handler(cls, topic: str, handler_f):
def register_request_handler(cls, topic: str, handler_f, fl_ctx: FLContext):
"""Register a handler for the reliable message with this topic

Args:
topic: The topic of the reliable message
handler_f: The callback function to handle the request in the form of
handler_f(topic, request, fl_ctx)
fl_ctx: FL Context
"""
if not cls._enabled:
raise RuntimeError("ReliableMessage is not enabled. Please call ReliableMessage.enable() to enable it")
if not callable(handler_f):
raise TypeError(f"handler_f must be callable but {type(handler_f)}")
cls._topic_to_handle[topic] = handler_f

# ReliableMessage also sends aux message directly if tx_timeout is too small
engine = fl_ctx.get_engine()
engine.register_aux_message_handler(
topic=topic,
message_handle_func=handler_f,
)

@classmethod
def _get_or_create_receiver(cls, topic: str, request: Shareable, handler_f) -> _RequestReceiver:
tx_id = request.get_header(HEADER_TX_ID)
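
With this change, register_request_handler needs an FLContext so the handler can also be registered directly as an aux-message handler, the fallback used when tx_timeout is too small for the reliable path. A minimal usage sketch, assuming ReliableMessage has already been enabled via ReliableMessage.enable(); the topic string and handler body are made up for illustration:

from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.reliable_message import ReliableMessage

def handle_xgb_request(topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
    # Hypothetical handler: echo the incoming payload back to the caller.
    reply = Shareable()
    reply["echo"] = request.get("payload")
    return reply

def register(fl_ctx: FLContext):
    # Registers the handler for the reliable path and, per the change above,
    # also as a plain aux-message handler on the engine.
    ReliableMessage.register_request_handler(
        topic="xgb.request",  # hypothetical topic name
        handler_f=handle_xgb_request,
        fl_ctx=fl_ctx,
    )
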
grpc_client_adaptor.py
@@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time

import grpc

import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2
@@ -23,14 +26,12 @@
from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port
from nvflare.security.logging import secure_format_exception

DUPLICATE_REQ_MAX_HOLD_TIME = 3600.0


class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer):
def __init__(
self,
int_server_grpc_options=None,
in_process=True,
):
XGBClientAdaptor.__init__(self, in_process)
def __init__(self, int_server_grpc_options=None, in_process=True, per_msg_timeout=10.0, tx_timeout=100.0):
XGBClientAdaptor.__init__(self, in_process, per_msg_timeout, tx_timeout)
self.int_server_grpc_options = int_server_grpc_options
self.in_process = in_process
self.internal_xgb_server = None
@@ -41,6 +42,8 @@ def __init__(
self._app_dir = None
self._workspace = None
self._run_dir = None
self._lock = threading.Lock()
self._pending_req = {}

def initialize(self, fl_ctx: FLContext):
self._client_name = fl_ctx.get_identity_name()
@@ -129,59 +132,108 @@ def _abort(self, reason: str):

def Allgather(self, request: pb2.AllgatherRequest, context):
try:
if self._check_duplicate_seq("allgather", request.rank, request.sequence_number):
return pb2.AllgatherReply(receive_buffer=bytes())

rcv_buf, _ = self._send_all_gather(
rank=request.rank,
seq=request.sequence_number,
send_buf=request.send_buffer,
)

return pb2.AllgatherReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather exception: {secure_format_exception(ex)}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllgatherReply(receive_buffer=None)
finally:
self._finish_pending_req("allgather", request.rank, request.sequence_number)

def AllgatherV(self, request: pb2.AllgatherVRequest, context):
try:
if self._check_duplicate_seq("allgatherv", request.rank, request.sequence_number):
return pb2.AllgatherVReply(receive_buffer=bytes())

rcv_buf = self._do_all_gather_v(
rank=request.rank,
seq=request.sequence_number,
send_buf=request.send_buffer,
)

return pb2.AllgatherVReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_gather_v exception: {secure_format_exception(ex)}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllgatherVReply(receive_buffer=None)
finally:
self._finish_pending_req("allgatherv", request.rank, request.sequence_number)

def Allreduce(self, request: pb2.AllreduceRequest, context):
try:
if self._check_duplicate_seq("allreduce", request.rank, request.sequence_number):
return pb2.AllreduceReply(receive_buffer=bytes())

rcv_buf, _ = self._send_all_reduce(
rank=request.rank,
seq=request.sequence_number,
data_type=request.data_type,
reduce_op=request.reduce_operation,
send_buf=request.send_buffer,
)

return pb2.AllreduceReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_all_reduce exception: {secure_format_exception(ex)}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.AllreduceReply(receive_buffer=None)
finally:
self._finish_pending_req("allreduce", request.rank, request.sequence_number)

def Broadcast(self, request: pb2.BroadcastRequest, context):
try:
if self._check_duplicate_seq("broadcast", request.rank, request.sequence_number):
return pb2.BroadcastReply(receive_buffer=bytes())

rcv_buf = self._do_broadcast(
rank=request.rank,
send_buf=request.send_buffer,
seq=request.sequence_number,
root=request.root,
)

return pb2.BroadcastReply(receive_buffer=rcv_buf)
except Exception as ex:
self._abort(reason=f"send_broadcast exception: {secure_format_exception(ex)}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(str(ex))
return pb2.BroadcastReply(receive_buffer=None)
finally:
self._finish_pending_req("broadcast", request.rank, request.sequence_number)

def _check_duplicate_seq(self, op: str, rank: int, seq: int):
with self._lock:
event = self._pending_req.get((rank, seq), None)
if event:
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=}, wait till original req is done")
event.wait(DUPLICATE_REQ_MAX_HOLD_TIME)
time.sleep(1) # To ensure the first request is returned first
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=} returned with empty buffer")
return True

with self._lock:
self._pending_req[(rank, seq)] = threading.Event()
return False

def _finish_pending_req(self, op: str, rank: int, seq: int):
with self._lock:
event = self._pending_req.get((rank, seq), None)
if not event:
self.logger.error(f"No pending req {op=} {rank=} {seq=}")
return

event.set()
del self._pending_req[(rank, seq)]
self.logger.info(f"Request seq {op=} {rank=} {seq=} finished processing")