diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc
index aaf9335ec2..749d8e98b5 100644
--- a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc
+++ b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc
@@ -155,7 +155,7 @@ std::vector NVFlareProcessor::HandleAggregation(void *buffer, size_t buf
 }
 
 void *NVFlareProcessor::ProcessHistograms(size_t *size, const std::vector<double>& histograms) {
-  cout << "HandleHistograms called with " << histograms.size() << " entries" << endl;
+  cout << "ProcessHistograms called with " << histograms.size() << " entries" << endl;
 
   DamEncoder encoder(kDataSetHistograms);
   encoder.AddFloatArray(histograms);
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
index 1fe65e7b45..20e0ad6d41 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
@@ -220,7 +220,7 @@ def _check_duplicate_seq(self, op: str, rank: int, seq: int):
         if event:
             self.logger.info(f"Duplicate seq {op=} {rank=} {seq=}, wait till original req is done")
             event.wait(DUPLICATE_REQ_MAX_HOLD_TIME)
-            time.sleep(1) # To ensure the first request is returned first
+            time.sleep(1)  # To ensure the first request is returned first
             self.logger.info(f"Duplicate seq {op=} {rank=} {seq=} returned with empty buffer")
             return True
 
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
index 39f7e32bd4..f1ca935ea6 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
@@ -100,7 +100,7 @@ class Constant:
     EVENT_AFTER_ALL_GATHER_V = "xgb.after_all_gather_v"
 
     HEADER_KEY_ENCRYPTED_DATA = "xgb.encrypted_data"
-    HEADER_KEY_ENCRYPTED_HISTOGRAMS = "xgb.encrypted_histograms"
+    HEADER_KEY_HORIZONTAL = "xgb.horizontal"
     HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size"
     HEADER_KEY_IN_AGGR = "xgb.in_aggr"
 
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
index cd9333ab00..889f88627d 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
@@ -14,15 +14,15 @@
 import os
 import time
 
-from nvflare.app_opt.he import decomposers
-
-from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace
-from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
+import tenseal as ts
+from tenseal.tensors.ckksvector import CKKSVector
 
 from nvflare.apis.event_type import EventType
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
+from nvflare.app_opt.he import decomposers
+from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace
 from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
 from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
 from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.adder import Adder
@@ -37,12 +37,13 @@
     generate_keys,
     split,
 )
+from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
-from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import ProcessorDataConverter, \
-    DATA_SET_HISTOGRAMS
+from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import (
+    DATA_SET_HISTOGRAMS,
+    ProcessorDataConverter,
+)
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler
-import tenseal as ts
-from tenseal.tensors.ckksvector import CKKSVector
 
 
 class ClientSecurityHandler(SecurityHandler):
@@ -139,14 +140,14 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext):
 
         buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
         decoder = DamDecoder(buffer)
-        if not decoder.isValid():
+        if not decoder.is_valid():
             self.info(fl_ctx, "Not secure content - ignore")
             return
 
         if decoder.get_data_set_id() == DATA_SET_HISTOGRAMS:
-            self._process_before_all_gather_v_horizontal(fl_ctx, decoder)
+            self._process_before_all_gather_v_horizontal(fl_ctx)
         else:
-            self._process_before_all_gather_v_vertical(fl_ctx, decoder)
+            self._process_before_all_gather_v_vertical(fl_ctx)
 
     def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext):
         rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
@@ -216,12 +217,15 @@ def _process_before_all_gather_v_horizontal(self, fl_ctx: FLContext):
         start = time.time()
         vector = ts.ckks_vector(self.tenseal_context, histograms)
         self.info(
-            fl_ctx, f"_process_before_all_gather_v: Histograms with {len(histograms)} entries "
-            f"encrypted in {time.time()-start} secs"
+            fl_ctx,
+            f"_process_before_all_gather_v: Histograms with {len(histograms)} entries "
+            f"encrypted in {time.time()-start} secs",
         )
-        headers = {Constant.HEADER_KEY_ENCRYPTED_DATA: True,
-                   Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS: True,
-                   Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer)}
+        headers = {
+            Constant.HEADER_KEY_ENCRYPTED_DATA: True,
+            Constant.HEADER_KEY_HORIZONTAL: True,
+            Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer),
+        }
         fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=vector, private=True, sticky=False)
         fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)
 
@@ -279,8 +283,8 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext):
             self.info(fl_ctx, "no encrypted result - ignore")
             return
 
-        has_histograms = reply.get_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
-        if has_histograms:
+        horizontal = reply.get_header(Constant.HEADER_KEY_HORIZONTAL)
+        if horizontal:
             self._process_after_all_gather_v_horizontal(fl_ctx)
         else:
             self._process_after_all_gather_v_vertical(fl_ctx)
@@ -360,11 +364,9 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
             try:
                 self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx)
             except Exception as ex:
-                self.info(fl_ctx,
-                          f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}")
+                self.info(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}")
                 self.tenseal_context = None
         elif event_type == EventType.END_RUN:
             self.tenseal_context = None
         else:
             super().handle_event(event_type, fl_ctx)
-
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py
index 224b60325b..63298c5fb2 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/processor_data_converter.py
@@ -150,4 +150,3 @@ def to_float_array(result: FeatureAggregationResult) -> List[float]:
             float_array.append(ProcessorDataConverter.int_to_float(h))
 
         return float_array
-
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py
index 4aad567fe1..119d0f5570 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py
@@ -14,11 +14,10 @@
 import os
 import threading
 
-from nvflare.app_opt.he import decomposers
-
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
+from nvflare.app_opt.he import decomposers
 from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler
 
@@ -94,18 +93,22 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext):
             self.info(fl_ctx, "start - non-secure data")
             return
 
-        has_histograms = request.get_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
-        split_mode = "horizontal" if has_histograms else "vertical"
+        horizontal = request.get_header(Constant.HEADER_KEY_HORIZONTAL)
+        split_mode = "horizontal" if horizontal else "vertical"
         self.info(fl_ctx, f"start - {split_mode}")
 
         fl_ctx.set_prop(key=Constant.HEADER_KEY_IN_AGGR, value=True, private=True, sticky=False)
-        fl_ctx.set_prop(key=Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS, value=has_histograms, private=True, sticky=False)
+        fl_ctx.set_prop(key=Constant.HEADER_KEY_HORIZONTAL, value=horizontal, private=True, sticky=False)
 
         rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
         send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
         if send_buf:
+            if horizontal:
+                length = send_buf.size()
+            else:
+                length = len(send_buf)
             # the send_buf contains encoded aggr result (str) or CKKS vector from this rank
-            self.info(fl_ctx, f"got encrypted aggr data: {len(send_buf)} bytes")
+            self.info(fl_ctx, f"got encrypted aggr data: {length} bytes")
             with self.aggr_result_lock:
                 self.aggr_result_to_send = None
                 if not self.aggr_result_dict:
@@ -133,22 +136,28 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext):
         reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
         assert isinstance(reply, Shareable)
 
-        has_histograms = fl_ctx.get_prop(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
+        horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL)
         reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True)
-        reply.set_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS, has_histograms)
+        reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal)
 
         with self.aggr_result_lock:
             if not self.aggr_result_to_send:
                 if not self.aggr_result_dict:
                     return self._abort(f"Rank {rank}: no aggr result after AllGatherV!", fl_ctx)
-                if has_histograms:
+                if horizontal:
                     self.aggr_result_to_send = self._histogram_sum(fl_ctx)
                 else:
                     self.aggr_result_to_send = self.aggr_result_dict
 
                 # reset aggr_result_dict for next gather
                 self.aggr_result_dict = None
-            self.info(fl_ctx, f"aggr_result_to_send {len(self.aggr_result_to_send)}")
+
+            if horizontal:
+                length = self.aggr_result_to_send.size()
+            else:
+                length = len(self.aggr_result_to_send)
+
+            self.info(fl_ctx, f"aggr_result_to_send {length}")
         fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.aggr_result_to_send, private=True, sticky=False)
 
     def _histogram_sum(self, fl_ctx: FLContext):
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
index 69a418235d..f27e9a4c0d 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
@@ -16,9 +16,12 @@
 
 from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
 
+COL_SECURE = 2
+ROW_SECURE = 3
+
 
 class SecureDataLoader(XGBDataLoader):
-    def __init__(self, rank: int, folder: str):
+    def __init__(self, rank: int, folder: str, data_split_mode=COL_SECURE):
         """Reads CSV dataset and return XGB data matrix in vertical secure mode.
 
         Args:
@@ -27,18 +30,19 @@ def __init__(self, rank: int, folder: str):
         """
         self.rank = rank
         self.folder = folder
+        self.data_split_mode = data_split_mode
 
     def load_data(self, client_id: str):
         train_path = f"{self.folder}/site-{self.rank + 1}/train.csv"
         valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv"
 
-        if self.rank == 0:
+        if self.rank == 0 or self.data_split_mode == ROW_SECURE:
             label = "&label_column=0"
         else:
             label = ""
 
-        train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=2)
-        valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=2)
+        train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=self.data_split_mode)
+        valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=self.data_split_mode)
 
         return train_data, valid_data
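
Note: the snippet below is a usage sketch, not part of the patch above. It shows how the updated SecureDataLoader might be constructed for horizontal (row-split) secure training via the new data_split_mode parameter. COL_SECURE (2, the previously hard-coded value) and ROW_SECURE (3) come from the patch; the data folder path and client id are made-up examples.

    # Hypothetical usage of the new data_split_mode parameter (illustrative path and client id).
    from nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader import ROW_SECURE, SecureDataLoader

    # Data is expected under {folder}/site-{rank + 1}/train.csv and valid.csv, same as the vertical case.
    loader = SecureDataLoader(rank=0, folder="/data/xgb_dataset", data_split_mode=ROW_SECURE)

    # In ROW_SECURE (horizontal) mode every site holds the label column, so load_data()
    # appends "&label_column=0" for all ranks, not just rank 0.
    train_dmatrix, valid_dmatrix = loader.load_data(client_id="site-1")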