diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc index d768d497dd..27c3512946 100644 --- a/integration/xgboost/processor/src/dam/dam.cc +++ b/integration/xgboost/processor/src/dam/dam.cc @@ -26,7 +26,7 @@ void print_buffer(uint8_t *buffer, int size) { } // DamEncoder ====== -void DamEncoder::AddFloatArray(std::vector &value) { +void DamEncoder::AddFloatArray(const std::vector &value) { if (encoded) { std::cout << "Buffer is already encoded" << std::endl; return; @@ -38,7 +38,7 @@ void DamEncoder::AddFloatArray(std::vector &value) { entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size())); } -void DamEncoder::AddIntArray(std::vector &value) { +void DamEncoder::AddIntArray(const std::vector &value) { std::cout << "AddIntArray called, size: " << value.size() << std::endl; if (encoded) { std::cout << "Buffer is already encoded" << std::endl; diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h index e6afd44299..1f113d92fe 100644 --- a/integration/xgboost/processor/src/include/dam.h +++ b/integration/xgboost/processor/src/include/dam.h @@ -53,9 +53,9 @@ class DamEncoder { this->data_set_id = data_set_id; } - void AddIntArray(std::vector &value); + void AddIntArray(const std::vector &value); - void AddFloatArray(std::vector &value); + void AddFloatArray(const std::vector &value); std::uint8_t * Finish(size_t &size); diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/processor/src/include/nvflare_processor.h index 52cf42920f..cc6fb6b1a4 100644 --- a/integration/xgboost/processor/src/include/nvflare_processor.h +++ b/integration/xgboost/processor/src/include/nvflare_processor.h @@ -24,6 +24,8 @@ const int kDataSetHGPairs = 1; const int kDataSetAggregation = 2; const int kDataSetAggregationWithFeatures = 3; const int kDataSetAggregationResult = 4; +const int kDataSetHistograms = 5; +const int kDataSetHistogramResult = 6; class NVFlareProcessor: public processing::Processor { private: @@ -51,11 +53,11 @@ class NVFlareProcessor: public processing::Processor { free(buffer); } - void* ProcessGHPairs(size_t &size, std::vector& pairs) override; + void* ProcessGHPairs(size_t *size, const std::vector& pairs) override; - void* HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override; + void* HandleGHPairs(size_t *size, void *buffer, size_t buf_size) override; - void InitAggregationContext(const std::vector &cuts, std::vector &slots) override { + void InitAggregationContext(const std::vector &cuts, const std::vector &slots) override { if (this->slots_.empty()) { this->cuts_ = std::vector(cuts); this->slots_ = std::vector(slots); @@ -64,8 +66,11 @@ class NVFlareProcessor: public processing::Processor { } } - void *ProcessAggregation(size_t &size, std::map> nodes) override; + void *ProcessAggregation(size_t *size, std::map> nodes) override; std::vector HandleAggregation(void *buffer, size_t buf_size) override; + void *ProcessHistograms(size_t *size, const std::vector& histograms) override; + + std::vector HandleHistograms(void *buffer, size_t buf_size) override; }; \ No newline at end of file diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc index dce1701f7e..749d8e98b5 100644 --- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc +++ b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc @@ -23,24 +23,24 @@ using std::vector; using std::cout; using std::endl; -void* NVFlareProcessor::ProcessGHPairs(size_t &size, std::vector& pairs) { +void* NVFlareProcessor::ProcessGHPairs(size_t *size, const std::vector& pairs) { cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl; gh_pairs_ = new std::vector(pairs); DamEncoder encoder(kDataSetHGPairs); encoder.AddFloatArray(pairs); - auto buffer = encoder.Finish(size); + auto buffer = encoder.Finish(*size); return buffer; } -void* NVFlareProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) { +void* NVFlareProcessor::HandleGHPairs(size_t *size, void *buffer, size_t buf_size) { cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl; - size = buf_size; + *size = buf_size; return buffer; } -void *NVFlareProcessor::ProcessAggregation(size_t &size, std::map> nodes) { +void *NVFlareProcessor::ProcessAggregation(size_t *size, std::map> nodes) { cout << "ProcessAggregation called with " << nodes.size() << " nodes" << endl; int64_t data_set_id; @@ -107,7 +107,7 @@ void *NVFlareProcessor::ProcessAggregation(size_t &size, std::map NVFlareProcessor::HandleAggregation(void *buffer, size_t buf while (remaining > kPrefixLen) { DamDecoder decoder(reinterpret_cast(pointer), remaining); if (!decoder.IsValid()) { - cout << "Not DAM encoded buffer ignored at offset: " << (int)(pointer - (char *)buffer) << endl; + cout << "Not DAM encoded buffer ignored at offset: " + << static_cast((pointer - reinterpret_cast(buffer))) << endl; break; } auto size = decoder.Size(); @@ -153,6 +154,31 @@ std::vector NVFlareProcessor::HandleAggregation(void *buffer, size_t buf return result; } +void *NVFlareProcessor::ProcessHistograms(size_t *size, const std::vector& histograms) { + cout << "ProcessHistograms called with " << histograms.size() << " entries" << endl; + + DamEncoder encoder(kDataSetHistograms); + encoder.AddFloatArray(histograms); + return encoder.Finish(*size); +} + +std::vector NVFlareProcessor::HandleHistograms(void *buffer, size_t buf_size) { + cout << "HandleHistograms called with buffer size: " << buf_size << endl; + + DamDecoder decoder(reinterpret_cast(buffer), buf_size); + if (!decoder.IsValid()) { + cout << "Not DAM encoded buffer, ignored" << endl; + return std::vector(); + } + + if (decoder.GetDataSetId() != kDataSetHistogramResult) { + cout << "Invalid dataset: " << decoder.GetDataSetId() << endl; + return std::vector(); + } + + return decoder.DecodeFloatArray(); +} + extern "C" { processing::Processor *LoadProcessor(char *plugin_name) { @@ -163,4 +189,5 @@ processing::Processor *LoadProcessor(char *plugin_name) { return new NVFlareProcessor(); } -} + +} // extern "C" diff --git a/nvflare/apis/utils/reliable_message.py b/nvflare/apis/utils/reliable_message.py index 802d2aff2e..8f84613b61 100644 --- a/nvflare/apis/utils/reliable_message.py +++ b/nvflare/apis/utils/reliable_message.py @@ -216,13 +216,14 @@ class ReliableMessage: _logger = logging.getLogger("ReliableMessage") @classmethod - def register_request_handler(cls, topic: str, handler_f): + def register_request_handler(cls, topic: str, handler_f, fl_ctx: FLContext): """Register a handler for the reliable message with this topic Args: topic: The topic of the reliable message handler_f: The callback function to handle the request in the form of handler_f(topic, request, fl_ctx) + fl_ctx: FL Context """ if not cls._enabled: raise RuntimeError("ReliableMessage is not enabled. Please call ReliableMessage.enable() to enable it") @@ -230,6 +231,13 @@ def register_request_handler(cls, topic: str, handler_f): raise TypeError(f"handler_f must be callable but {type(handler_f)}") cls._topic_to_handle[topic] = handler_f + # ReliableMessage also sends aux message directly if tx_timeout is too small + engine = fl_ctx.get_engine() + engine.register_aux_message_handler( + topic=topic, + message_handle_func=handler_f, + ) + @classmethod def _get_or_create_receiver(cls, topic: str, request: Shareable, handler_f) -> _RequestReceiver: tx_id = request.get_header(HEADER_TX_ID) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py index ae5e291ce1..e5fb71d0f5 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading +import time + import grpc import nvflare.app_opt.xgboost.histogram_based_v2.proto.federated_pb2 as pb2 @@ -23,14 +26,12 @@ from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port from nvflare.security.logging import secure_format_exception +DUPLICATE_REQ_MAX_HOLD_TIME = 3600.0 + class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer): - def __init__( - self, - int_server_grpc_options=None, - in_process=True, - ): - XGBClientAdaptor.__init__(self, in_process) + def __init__(self, int_server_grpc_options=None, in_process=True, per_msg_timeout=10.0, tx_timeout=100.0): + XGBClientAdaptor.__init__(self, in_process, per_msg_timeout, tx_timeout) self.int_server_grpc_options = int_server_grpc_options self.in_process = in_process self.internal_xgb_server = None @@ -41,6 +42,8 @@ def __init__( self._app_dir = None self._workspace = None self._run_dir = None + self._lock = threading.Lock() + self._pending_req = {} def initialize(self, fl_ctx: FLContext): self._client_name = fl_ctx.get_identity_name() @@ -129,34 +132,49 @@ def _abort(self, reason: str): def Allgather(self, request: pb2.AllgatherRequest, context): try: + if self._check_duplicate_seq("allgather", request.rank, request.sequence_number): + return pb2.AllgatherReply(receive_buffer=bytes()) + rcv_buf, _ = self._send_all_gather( rank=request.rank, seq=request.sequence_number, send_buf=request.send_buffer, ) + return pb2.AllgatherReply(receive_buffer=rcv_buf) except Exception as ex: self._abort(reason=f"send_all_gather exception: {secure_format_exception(ex)}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(ex)) return pb2.AllgatherReply(receive_buffer=None) + finally: + self._finish_pending_req("allgather", request.rank, request.sequence_number) def AllgatherV(self, request: pb2.AllgatherVRequest, context): try: + if self._check_duplicate_seq("allgatherv", request.rank, request.sequence_number): + return pb2.AllgatherVReply(receive_buffer=bytes()) + rcv_buf = self._do_all_gather_v( rank=request.rank, seq=request.sequence_number, send_buf=request.send_buffer, ) + return pb2.AllgatherVReply(receive_buffer=rcv_buf) except Exception as ex: self._abort(reason=f"send_all_gather_v exception: {secure_format_exception(ex)}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(ex)) return pb2.AllgatherVReply(receive_buffer=None) + finally: + self._finish_pending_req("allgatherv", request.rank, request.sequence_number) def Allreduce(self, request: pb2.AllreduceRequest, context): try: + if self._check_duplicate_seq("allreduce", request.rank, request.sequence_number): + return pb2.AllreduceReply(receive_buffer=bytes()) + rcv_buf, _ = self._send_all_reduce( rank=request.rank, seq=request.sequence_number, @@ -164,24 +182,58 @@ def Allreduce(self, request: pb2.AllreduceRequest, context): reduce_op=request.reduce_operation, send_buf=request.send_buffer, ) + return pb2.AllreduceReply(receive_buffer=rcv_buf) except Exception as ex: self._abort(reason=f"send_all_reduce exception: {secure_format_exception(ex)}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(ex)) return pb2.AllreduceReply(receive_buffer=None) + finally: + self._finish_pending_req("allreduce", request.rank, request.sequence_number) def Broadcast(self, request: pb2.BroadcastRequest, context): try: + if self._check_duplicate_seq("broadcast", request.rank, request.sequence_number): + return pb2.BroadcastReply(receive_buffer=bytes()) + rcv_buf = self._do_broadcast( rank=request.rank, send_buf=request.send_buffer, seq=request.sequence_number, root=request.root, ) + return pb2.BroadcastReply(receive_buffer=rcv_buf) except Exception as ex: self._abort(reason=f"send_broadcast exception: {secure_format_exception(ex)}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(ex)) return pb2.BroadcastReply(receive_buffer=None) + finally: + self._finish_pending_req("broadcast", request.rank, request.sequence_number) + + def _check_duplicate_seq(self, op: str, rank: int, seq: int): + with self._lock: + event = self._pending_req.get((rank, seq), None) + if event: + self.logger.info(f"Duplicate seq {op=} {rank=} {seq=}, wait till original req is done") + event.wait(DUPLICATE_REQ_MAX_HOLD_TIME) + time.sleep(1) # To ensure the first request is returned first + self.logger.info(f"Duplicate seq {op=} {rank=} {seq=} returned with empty buffer") + return True + + with self._lock: + self._pending_req[(rank, seq)] = threading.Event() + return False + + def _finish_pending_req(self, op: str, rank: int, seq: int): + with self._lock: + event = self._pending_req.get((rank, seq), None) + if not event: + self.logger.error(f"No pending req {op=} {rank=} {seq=}") + return + + event.set() + del self._pending_req[(rank, seq)] + self.logger.info(f"Request seq {op=} {rank=} {seq=} finished processing") diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py index 073a47bdd9..d0ec51de7c 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py @@ -13,10 +13,12 @@ # limitations under the License. from abc import abstractmethod +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable +from nvflare.apis.utils.reliable_message import ReliableMessage from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant -from nvflare.app_opt.xgboost.histogram_based_v2.sender import Sender +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_positive_int from .adaptor import AppAdaptor @@ -129,28 +131,25 @@ class XGBClientAdaptor(AppAdaptor): XGBClientAdaptor specifies commonly required methods for client adaptor implementations. """ - def __init__(self, in_process): + def __init__(self, in_process, per_msg_timeout: float, tx_timeout: float): """Constructor of XGBClientAdaptor""" AppAdaptor.__init__(self, XGB_APP_NAME, in_process) self.engine = None - self.sender = None self.stopped = False self.rank = None self.num_rounds = None self.world_size = None + self.per_msg_timeout = per_msg_timeout + self.tx_timeout = tx_timeout - def set_sender(self, sender: Sender): - """Set the sender to be used to send XGB operation requests to the server. - - Args: - sender: the sender to be set + def start(self, fl_ctx: FLContext): + pass - Returns: None + def stop(self, fl_ctx: FLContext): + pass - """ - if not isinstance(sender, Sender): - raise TypeError(f"sender must be Sender but got {type(sender)}") - self.sender = sender + def _is_stopped(self) -> (bool, int): + pass def configure(self, config: dict, fl_ctx: FLContext): """Called by XGB Executor to configure the target. @@ -195,8 +194,28 @@ def _send_request(self, op: str, req: Shareable) -> (bytes, Shareable): Returns: operation result """ - reply = self.sender.send_to_server(op, req, self.abort_signal) + req.set_header(Constant.MSG_KEY_XGB_OP, op) + + with self.engine.new_context() as fl_ctx: + reply = ReliableMessage.send_request( + target=FQCN.ROOT_SERVER, + topic=Constant.TOPIC_XGB_REQUEST, + request=req, + per_msg_timeout=self.per_msg_timeout, + tx_timeout=self.tx_timeout, + abort_signal=self.abort_signal, + fl_ctx=fl_ctx, + ) + if isinstance(reply, Shareable): + rc = reply.get_return_code() + if rc != ReturnCode.OK: + raise RuntimeError(f"received error return code: {rc}") + + reply_op = reply.get_header(Constant.MSG_KEY_XGB_OP) + if reply_op != op: + raise RuntimeError(f"received op {reply_op} != expected op {op}") + rcv_buf = reply.get(Constant.PARAM_KEY_RCV_BUF) return rcv_buf, reply else: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py index 8040cf2922..39370b7033 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py @@ -20,6 +20,7 @@ from nvflare.apis.impl.controller import Controller from nvflare.apis.shareable import ReturnCode, Shareable, make_reply from nvflare.apis.signal import Signal +from nvflare.apis.utils.reliable_message import ReliableMessage from nvflare.app_opt.xgboost.histogram_based_v2.adaptors.xgb_adaptor import XGBServerAdaptor from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str from nvflare.security.logging import secure_format_exception @@ -147,14 +148,15 @@ def start_controller(self, fl_ctx: FLContext): adaptor.initialize(fl_ctx) self.adaptor = adaptor - engine = fl_ctx.get_engine() - engine.register_aux_message_handler( + ReliableMessage.register_request_handler( topic=Constant.TOPIC_XGB_REQUEST, - message_handle_func=self._process_xgb_request, + handler_f=self._process_xgb_request, + fl_ctx=fl_ctx, ) - engine.register_aux_message_handler( + ReliableMessage.register_request_handler( topic=Constant.TOPIC_CLIENT_DONE, - message_handle_func=self._process_client_done, + handler_f=self._process_client_done, + fl_ctx=fl_ctx, ) def _trigger_stop(self, fl_ctx: FLContext, error=None): @@ -328,8 +330,8 @@ def _process_broadcast(self, request: Shareable, fl_ctx: FLContext) -> Shareable send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) assert isinstance(self.adaptor, XGBServerAdaptor) rcv_buf = self.adaptor.broadcast(rank, seq, root, send_buf, fl_ctx) - reply = Shareable() + reply = Shareable() fl_ctx.set_prop(key=Constant.PARAM_KEY_REPLY, value=reply, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=rcv_buf, private=True, sticky=False) self.fire_event(Constant.EVENT_AFTER_BROADCAST, fl_ctx) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py index 1ee77acd6f..f1ca935ea6 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py @@ -100,7 +100,9 @@ class Constant: EVENT_AFTER_ALL_GATHER_V = "xgb.after_all_gather_v" HEADER_KEY_ENCRYPTED_DATA = "xgb.encrypted_data" + HEADER_KEY_HORIZONTAL = "xgb.horizontal" HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size" + HEADER_KEY_IN_AGGR = "xgb.in_aggr" DUMMY_BUFFER_SIZE = 4 diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/executor.py b/nvflare/app_opt/xgboost/histogram_based_v2/executor.py index 8c0e727dc5..dc870ff094 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/executor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/executor.py @@ -23,7 +23,6 @@ from nvflare.security.logging import secure_format_exception from .defs import Constant -from .sender import Sender class XGBExecutor(Executor): @@ -32,7 +31,8 @@ def __init__( adaptor_component_id: str, configure_task_name=Constant.CONFIG_TASK_NAME, start_task_name=Constant.START_TASK_NAME, - req_timeout=60.0, + per_msg_timeout=10.0, + tx_timeout=100.0, ): """Constructor @@ -40,10 +40,13 @@ def __init__( adaptor_component_id: the component ID of client target adaptor configure_task_name: name of the config task start_task_name: name of the start task + per_msg_timeout: timeout for sending one message + tx_timeout: transaction timeout """ Executor.__init__(self) self.adaptor_component_id = adaptor_component_id - self.req_timeout = req_timeout + self.per_msg_timeout = per_msg_timeout + self.tx_timeout = tx_timeout self.configure_task_name = configure_task_name self.start_task_name = start_task_name self.adaptor = None @@ -80,8 +83,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): return adaptor.set_abort_signal(self.abort_signal) - engine = fl_ctx.get_engine() - adaptor.set_sender(Sender(engine, self.req_timeout)) adaptor.initialize(fl_ctx) self.adaptor = adaptor elif event_type == EventType.END_RUN: diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py b/nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py index 107d5550aa..168d8d328a 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/fed_executor.py @@ -27,7 +27,8 @@ def __init__( verbose_eval=False, use_gpus=False, int_server_grpc_options=None, - req_timeout=60.0, + per_msg_timeout=60.0, + tx_timeout=600.0, model_file_name="model.json", metrics_writer_id: str = None, in_process=True, @@ -35,7 +36,8 @@ def __init__( XGBExecutor.__init__( self, adaptor_component_id="", - req_timeout=req_timeout, + per_msg_timeout=per_msg_timeout, + tx_timeout=tx_timeout, ) self.early_stopping_rounds = early_stopping_rounds self.xgb_params = xgb_params @@ -61,6 +63,8 @@ def get_adaptor(self, fl_ctx: FLContext): adaptor = GrpcClientAdaptor( int_server_grpc_options=self.int_server_grpc_options, in_process=self.in_process, + per_msg_timeout=self.per_msg_timeout, + tx_timeout=self.tx_timeout, ) adaptor.set_runner(runner) return adaptor diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py index fc836728e9..f3a9dbc905 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py @@ -133,6 +133,10 @@ def run(self, ctx: dict): "federated_server_address": f"{self._server_addr}", "federated_world_size": self._world_size, "federated_rank": self._rank, + "plugin_name": "nvflare", + "loader_params": { + "LIBRARY_PATH": "/tmp", + }, } with xgb.collective.CommunicatorContext(**communicator_env): # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py index 570db379c1..38cb7e8644 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py @@ -32,13 +32,28 @@ generate_keys, split, ) +from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult -from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import ProcessorDataConverter +from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import ( + DATA_SET_HISTOGRAMS, + ProcessorDataConverter, +) from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler +try: + import tenseal as ts + from tenseal.tensors.ckksvector import CKKSVector + + from nvflare.app_opt.he import decomposers + from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace + + tenseal_imported = True +except Exception: + tenseal_imported = False + class ClientSecurityHandler(SecurityHandler): - def __init__(self, key_length=1024, num_workers=10): + def __init__(self, key_length=1024, num_workers=10, tenseal_context_file="client_context.tenseal"): FLComponent.__init__(self) self.num_workers = num_workers self.key_length = key_length @@ -54,6 +69,11 @@ def __init__(self, key_length=1024, num_workers=10): self.feature_masks = None self.aggregator = Aggregator() self.aggr_result = None # for label client: computed aggr result based on clear-text clear_ghs + self.tenseal_context_file = tenseal_context_file + self.tenseal_context = None + + if tenseal_imported: + decomposers.register() def _process_before_broadcast(self, fl_ctx: FLContext): root = fl_ctx.get_prop(Constant.PARAM_KEY_ROOT) @@ -123,9 +143,22 @@ def _process_after_broadcast(self, fl_ctx: FLContext): fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=dummy_buf, private=True, sticky=False) def _process_before_all_gather_v(self, fl_ctx: FLContext): - rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) self.info(fl_ctx, "start") buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) + + decoder = DamDecoder(buffer) + if not decoder.is_valid(): + self.info(fl_ctx, "Not secure content - ignore") + return + + if decoder.get_data_set_id() == DATA_SET_HISTOGRAMS: + self._process_before_all_gather_v_horizontal(fl_ctx) + else: + self._process_before_all_gather_v_vertical(fl_ctx) + + def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext): + rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) + buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) aggr_ctx = self.data_converter.decode_aggregation_context(buffer, fl_ctx) if not aggr_ctx: @@ -181,6 +214,30 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext): fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=encoded_str, private=True, sticky=False) fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False) + def _process_before_all_gather_v_horizontal(self, fl_ctx: FLContext): + if not self.tenseal_context: + return self._abort( + "Horizontal secure XGBoost not supported due to missing context or missing module", fl_ctx + ) + + buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) + histograms = self.data_converter.decode_histograms(buffer, fl_ctx) + + start = time.time() + vector = ts.ckks_vector(self.tenseal_context, histograms) + self.info( + fl_ctx, + f"_process_before_all_gather_v: Histograms with {len(histograms)} entries " + f"encrypted in {time.time()-start} secs", + ) + headers = { + Constant.HEADER_KEY_ENCRYPTED_DATA: True, + Constant.HEADER_KEY_HORIZONTAL: True, + Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer), + } + fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=vector, private=True, sticky=False) + fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False) + def _do_aggregation(self, groups, fl_ctx: FLContext): # this is only for the label-client to compute aggregation in clear-text! if not self.feature_masks: @@ -192,13 +249,13 @@ def _do_aggregation(self, groups, fl_ctx: FLContext): fid, masks, num_bins = fm if not groups: gid = 0 - GH_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, None) - aggr_result.append((fid, gid, GH_list)) + gh_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, None) + aggr_result.append((fid, gid, gh_list)) else: for grp in groups: gid, sample_ids = grp - GH_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, sample_ids) - aggr_result.append((fid, gid, GH_list)) + gh_list = self.aggregator.aggregate(self.clear_ghs, masks, num_bins, sample_ids) + aggr_result.append((fid, gid, gh_list)) self.info(fl_ctx, f"aggregated clear-text in {time.time()-t} secs") self.aggr_result = aggr_result @@ -228,7 +285,6 @@ def _decrypt_aggr_result(self, encoded, fl_ctx: FLContext): def _process_after_all_gather_v(self, fl_ctx: FLContext): # called after AllGatherV result is received from the server self.info(fl_ctx, "start") - rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) assert isinstance(reply, Shareable) encrypted_data = reply.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA) @@ -236,9 +292,16 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): self.info(fl_ctx, "no encrypted result - ignore") return - rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) + horizontal = reply.get_header(Constant.HEADER_KEY_HORIZONTAL) + if horizontal: + self._process_after_all_gather_v_horizontal(fl_ctx) + else: + self._process_after_all_gather_v_vertical(fl_ctx) + def _process_after_all_gather_v_vertical(self, fl_ctx: FLContext): + rcv_buf = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) # this rcv_buf is a list of replies from ALL clients! + rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) if not isinstance(rcv_buf, dict): return self._abort(f"rank {rank}: expect a dict of aggr result but got {type(rcv_buf)}", fl_ctx) rank_replies = rcv_buf @@ -269,15 +332,15 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): for a in rr: fid, gid, combined_numbers = a - GH_list = [] + gh_list = [] for n in combined_numbers: - GH_list.append(split(n)) + gh_list.append(split(n)) grp_result = combined_result.get(gid) if not grp_result: grp_result = {} combined_result[gid] = grp_result - grp_result[fid] = FeatureAggregationResult(fid, GH_list) - self.info(fl_ctx, f"aggr from rank {r}: {fid=} {gid=} bins={len(GH_list)}") + grp_result[fid] = FeatureAggregationResult(fid, gh_list) + self.info(fl_ctx, f"aggr from rank {r}: {fid=} {gid=} bins={len(gh_list)}") final_result = {} for gid, far in combined_result.items(): @@ -291,11 +354,31 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): result = self.data_converter.encode_aggregation_result(final_result, fl_ctx) fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) + def _process_after_all_gather_v_horizontal(self, fl_ctx: FLContext): + encrypted_histograms = fl_ctx.get_prop(Constant.PARAM_KEY_RCV_BUF) + rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) + if not isinstance(encrypted_histograms, CKKSVector): + return self._abort(f"rank {rank}: expect a CKKSVector but got {type(encrypted_histograms)}", fl_ctx) + + histograms = encrypted_histograms.decrypt(secret_key=self.tenseal_context.secret_key()) + result = self.data_converter.encode_histograms_result(histograms, fl_ctx) + fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=result, private=True, sticky=False) + def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.START_RUN: self.public_key, self.private_key = generate_keys(self.key_length) self.encryptor = Encryptor(self.public_key, self.num_workers) self.decrypter = Decrypter(self.private_key, self.num_workers) self.adder = Adder(self.num_workers) + try: + if tenseal_imported: + self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx) + else: + self.debug(fl_ctx, "Tenseal module not loaded, horizontal secure XGBoost is not supported") + except Exception as ex: + self.debug(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}") + self.tenseal_context = None + elif event_type == EventType.END_RUN: + self.tenseal_context = None else: super().handle_event(event_type, fl_ctx) diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/data_converter.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/data_converter.py index 9a8416ce2b..a9b0a92657 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/data_converter.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/data_converter.py @@ -76,3 +76,28 @@ def encode_aggregation_result( """ pass + + def decode_histograms(self, buffer: bytes, fl_ctx: FLContext) -> List[float]: + """Decode the buffer to extract flattened histograms + + Args: + buffer: buffer to be decoded + fl_ctx: FLContext info + + Returns: if the buffer contains histograms, return the flattened histograms + otherwise, return None + + """ + pass + + def encode_histograms_result(self, histograms: List[float], fl_ctx: FLContext) -> bytes: + """Encode flattened histograms to be sent back to XGBoost + + Args: + histograms: The flattened histograms for all features + fl_ctx: FLContext info + + Returns: a buffer of bytes + + """ + pass diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py index c6acb7293d..63298c5fb2 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py @@ -26,6 +26,8 @@ DATA_SET_AGGREGATION = 2 DATA_SET_AGGREGATION_WITH_FEATURES = 3 DATA_SET_AGGREGATION_RESULT = 4 +DATA_SET_HISTOGRAMS = 5 +DATA_SET_HISTOGRAMS_RESULT = 6 SCALE_FACTOR = 1000000.0 # Preserve 6 decimal places @@ -101,6 +103,21 @@ def encode_aggregation_result( return encoder.finish() + def decode_histograms(self, buffer: bytes, fl_ctx: FLContext) -> List[float]: + decoder = DamDecoder(buffer) + if not decoder.is_valid(): + return None + data_set_id = decoder.get_data_set_id() + if data_set_id != DATA_SET_HISTOGRAMS: + raise RuntimeError(f"Invalid DataSet: {data_set_id}") + + return decoder.decode_float_array() + + def encode_histograms_result(self, histograms: List[float], fl_ctx: FLContext) -> bytes: + encoder = DamEncoder(DATA_SET_HISTOGRAMS_RESULT) + encoder.add_float_array(histograms) + return encoder.finish() + @staticmethod def get_bin_size(cuts: [int], feature_id: int) -> int: return cuts[feature_id + 1] - cuts[feature_id] diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py index d24b73dc76..53e936c7d4 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py @@ -20,6 +20,13 @@ from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler +try: + from nvflare.app_opt.he import decomposers + + tenseal_imported = True +except Exception: + tenseal_imported = False + class ServerSecurityHandler(SecurityHandler): def __init__(self): @@ -33,6 +40,9 @@ def __init__(self): self.aggr_result_to_send = None self.aggr_result_lock = threading.Lock() + if tenseal_imported: + decomposers.register() + def _process_before_broadcast(self, fl_ctx: FLContext): self.info(fl_ctx, "start") rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) @@ -82,21 +92,30 @@ def _process_after_broadcast(self, fl_ctx: FLContext): fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.encrypted_gh, private=True, sticky=False) def _process_before_all_gather_v(self, fl_ctx: FLContext): - self.info(fl_ctx, "start") - rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) request = fl_ctx.get_prop(Constant.PARAM_KEY_REQUEST) assert isinstance(request, Shareable) has_encrypted_data = request.get_header(Constant.HEADER_KEY_ENCRYPTED_DATA) self.info(fl_ctx, f"{has_encrypted_data=}") if not has_encrypted_data: - self.info(fl_ctx, "no encrypted data - ignore") + self.info(fl_ctx, "start - non-secure data") return - fl_ctx.set_prop(key="in_aggr", value=True, private=True, sticky=False) + horizontal = request.get_header(Constant.HEADER_KEY_HORIZONTAL) + split_mode = "horizontal" if horizontal else "vertical" + self.info(fl_ctx, f"start - {split_mode}") + + fl_ctx.set_prop(key=Constant.HEADER_KEY_IN_AGGR, value=True, private=True, sticky=False) + fl_ctx.set_prop(key=Constant.HEADER_KEY_HORIZONTAL, value=horizontal, private=True, sticky=False) + + rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF) if send_buf: - # the send_buf contains encoded aggr result (str) from this rank - self.info(fl_ctx, f"got encrypted aggr data: {len(send_buf)} bytes") + if horizontal: + length = send_buf.size() + else: + length = len(send_buf) + # the send_buf contains encoded aggr result (str) or CKKS vector from this rank + self.info(fl_ctx, f"got encrypted aggr data: {length} bytes") with self.aggr_result_lock: self.aggr_result_to_send = None if not self.aggr_result_dict: @@ -113,9 +132,9 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext): def _process_after_all_gather_v(self, fl_ctx: FLContext): # this is called after the Server has finished gathering - # Note: this fl_ctx is the same as the one in _handle_before_all_gather_v! + # Note: this fl_ctx is the same as the one in _process_before_all_gather_v! rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK) - in_aggr = fl_ctx.get_prop("in_aggr") + in_aggr = fl_ctx.get_prop(Constant.HEADER_KEY_IN_AGGR) self.info(fl_ctx, f"start {in_aggr=}") if not in_aggr: @@ -124,14 +143,38 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext): reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY) assert isinstance(reply, Shareable) + horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL) reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True) + reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal) with self.aggr_result_lock: if not self.aggr_result_to_send: if not self.aggr_result_dict: return self._abort(f"Rank {rank}: no aggr result after AllGatherV!", fl_ctx) - self.aggr_result_to_send = self.aggr_result_dict + + if horizontal: + self.aggr_result_to_send = self._histogram_sum(fl_ctx) + else: + self.aggr_result_to_send = self.aggr_result_dict # reset aggr_result_dict for next gather self.aggr_result_dict = None - self.info(fl_ctx, f"aggr_result_to_send {len(self.aggr_result_to_send)}") + + if horizontal: + length = self.aggr_result_to_send.size() + else: + length = len(self.aggr_result_to_send) + + self.info(fl_ctx, f"aggr_result_to_send {length}") fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.aggr_result_to_send, private=True, sticky=False) + + def _histogram_sum(self, fl_ctx: FLContext): + + result = None + + for rank, vector in self.aggr_result_dict.items(): + if not result: + result = vector + else: + result = result + vector + + return result diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py index 69a418235d..f27e9a4c0d 100644 --- a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py +++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py @@ -16,9 +16,12 @@ from nvflare.app_opt.xgboost.data_loader import XGBDataLoader +COL_SECURE = 2 +ROW_SECURE = 3 + class SecureDataLoader(XGBDataLoader): - def __init__(self, rank: int, folder: str): + def __init__(self, rank: int, folder: str, data_split_mode=COL_SECURE): """Reads CSV dataset and return XGB data matrix in vertical secure mode. Args: @@ -27,18 +30,19 @@ def __init__(self, rank: int, folder: str): """ self.rank = rank self.folder = folder + self.data_split_mode = data_split_mode def load_data(self, client_id: str): train_path = f"{self.folder}/site-{self.rank + 1}/train.csv" valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv" - if self.rank == 0: + if self.rank == 0 or self.data_split_mode == ROW_SECURE: label = "&label_column=0" else: label = "" - train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=2) - valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=2) + train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=self.data_split_mode) + valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=self.data_split_mode) return train_data, valid_data diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sender.py b/nvflare/app_opt/xgboost/histogram_based_v2/sender.py deleted file mode 100644 index 7177fbb214..0000000000 --- a/nvflare/app_opt/xgboost/histogram_based_v2/sender.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from nvflare.apis.shareable import ReturnCode, Shareable -from nvflare.apis.signal import Signal -from nvflare.fuel.f3.cellnet.fqcn import FQCN -from nvflare.fuel.utils.obj_utils import get_logger - -from .defs import Constant - - -class Sender: - """ - A Sender is used to send XGB requests from the client to the server and wait for reply. - TBD: currently the sender simply sends the request with an aux message. It will be enhanced to be more - reliable in dealing with unstable network. - """ - - def __init__(self, engine, timeout): - """Constructor - - Args: - engine: the client engine that can send aux messages - timeout: the timeout for XGB requests - """ - self.engine = engine - self.timeout = timeout - self.logger = get_logger(self) - - def _extract_result(self, reply, expected_op): - if not reply: - return None - if not isinstance(reply, dict): - self.logger.error(f"expect reply to be a dict but got {type(reply)}") - return None - result = reply.get(FQCN.ROOT_SERVER) - if not result: - self.logger.error(f"no reply from {FQCN.ROOT_SERVER} for request {expected_op}") - return None - if not isinstance(result, Shareable): - self.logger.error(f"expect result to be a Shareable but got {type(result)}") - return None - rc = result.get_return_code() - if rc != ReturnCode.OK: - self.logger.error(f"server failed to process request: {rc=}") - return None - reply_op = result.get_header(Constant.MSG_KEY_XGB_OP) - if reply_op != expected_op: - self.logger.error(f"received op {reply_op} != expected op {expected_op}") - return None - return result - - def send_to_server(self, op: str, req: Shareable, abort_signal: Signal): - """Send an XGB request to the server. - - Args: - op: the XGB operation code - req: the XGB request - abort_signal: used for checking whether the job is aborted. - - Returns: reply from the server - - Note: when this method is enhanced to be more reliable, we'll keep resending until either the request is - sent successfully or the job is aborted. - - """ - req.set_header(Constant.MSG_KEY_XGB_OP, op) - - server_name = FQCN.ROOT_SERVER - with self.engine.new_context() as fl_ctx: - reply = self.engine.send_aux_request( - targets=[server_name], - topic=Constant.TOPIC_XGB_REQUEST, - request=req, - timeout=self.timeout, - fl_ctx=fl_ctx, - ) - return self._extract_result(reply, op)