
Fixed a few horizontal issues
nvidianz committed May 6, 2024
1 parent 306b8e1 commit 6d75082
Showing 7 changed files with 53 additions and 39 deletions.
@@ -155,7 +155,7 @@ std::vector<double> NVFlareProcessor::HandleAggregation(void *buffer, size_t buf
}

void *NVFlareProcessor::ProcessHistograms(size_t *size, const std::vector<double>& histograms) {
cout << "HandleHistograms called with " << histograms.size() << " entries" << endl;
cout << "ProcessHistograms called with " << histograms.size() << " entries" << endl;

DamEncoder encoder(kDataSetHistograms);
encoder.AddFloatArray(histograms);
@@ -220,7 +220,7 @@ def _check_duplicate_seq(self, op: str, rank: int, seq: int):
if event:
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=}, wait till original req is done")
event.wait(DUPLICATE_REQ_MAX_HOLD_TIME)
time.sleep(1) # To ensure the first request is returned first
time.sleep(1) # To ensure the first request is returned first
self.logger.info(f"Duplicate seq {op=} {rank=} {seq=} returned with empty buffer")
return True

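Note: the hunk above is a formatting-only change, but the surrounding logic is the duplicate-request hold pattern: a request repeating an (op, rank, seq) already in flight waits on an event set by the original request, then returns an empty buffer. A minimal sketch of that pattern, with hypothetical names for the registry and helper:

```python
import threading
import time

DUPLICATE_REQ_MAX_HOLD_TIME = 600.0  # placeholder timeout, in seconds

# Hypothetical registry of in-flight requests; the real adaptor keeps its own
# bookkeeping, these names are illustrative only.
pending_req_events: dict = {}

def check_duplicate_seq(op: str, rank: int, seq: int) -> bool:
    """Return True if this (op, rank, seq) duplicates an in-flight request."""
    event = pending_req_events.get((op, rank, seq))
    if not event:
        return False
    # Hold the duplicate until the original request completes (or the hold
    # times out), then let the original reply go out first.
    event.wait(DUPLICATE_REQ_MAX_HOLD_TIME)
    time.sleep(1)
    return True
```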
2 changes: 1 addition & 1 deletion nvflare/app_opt/xgboost/histogram_based_v2/defs.py
@@ -100,7 +100,7 @@ class Constant:
EVENT_AFTER_ALL_GATHER_V = "xgb.after_all_gather_v"

HEADER_KEY_ENCRYPTED_DATA = "xgb.encrypted_data"
HEADER_KEY_ENCRYPTED_HISTOGRAMS = "xgb.encrypted_histograms"
HEADER_KEY_HORIZONTAL = "xgb.horizontal"
HEADER_KEY_ORIGINAL_BUF_SIZE = "xgb.original_buf_size"
HEADER_KEY_IN_AGGR = "xgb.in_aggr"

44 changes: 23 additions & 21 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
@@ -14,15 +14,15 @@
import os
import time

from nvflare.app_opt.he import decomposers

from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.he import decomposers
from nvflare.app_opt.he.homomorphic_encrypt import load_tenseal_context_from_workspace
from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.mock_he.adder import Adder
@@ -37,12 +37,13 @@
generate_keys,
split,
)
from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import ProcessorDataConverter, \
DATA_SET_HISTOGRAMS
from nvflare.app_opt.xgboost.histogram_based_v2.sec.processor_data_converter import (
DATA_SET_HISTOGRAMS,
ProcessorDataConverter,
)
from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler
import tenseal as ts
from tenseal.tensors.ckksvector import CKKSVector


class ClientSecurityHandler(SecurityHandler):
@@ -139,14 +140,14 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext):
buffer = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)

decoder = DamDecoder(buffer)
if not decoder.isValid():
if not decoder.is_valid():
self.info(fl_ctx, "Not secure content - ignore")
return

if decoder.get_data_set_id() == DATA_SET_HISTOGRAMS:
self._process_before_all_gather_v_horizontal(fl_ctx, decoder)
self._process_before_all_gather_v_horizontal(fl_ctx)
else:
self._process_before_all_gather_v_vertical(fl_ctx, decoder)
self._process_before_all_gather_v_vertical(fl_ctx)

def _process_before_all_gather_v_vertical(self, fl_ctx: FLContext):
rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
@@ -216,12 +217,15 @@ def _process_before_all_gather_v_horizontal(self, fl_ctx: FLContext):
start = time.time()
vector = ts.ckks_vector(self.tenseal_context, histograms)
self.info(
fl_ctx, f"_process_before_all_gather_v: Histograms with {len(histograms)} entries "
f"encrypted in {time.time()-start} secs"
fl_ctx,
f"_process_before_all_gather_v: Histograms with {len(histograms)} entries "
f"encrypted in {time.time()-start} secs",
)
headers = {Constant.HEADER_KEY_ENCRYPTED_DATA: True,
Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS: True,
Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer)}
headers = {
Constant.HEADER_KEY_ENCRYPTED_DATA: True,
Constant.HEADER_KEY_HORIZONTAL: True,
Constant.HEADER_KEY_ORIGINAL_BUF_SIZE: len(buffer),
}
fl_ctx.set_prop(key=Constant.PARAM_KEY_SEND_BUF, value=vector, private=True, sticky=False)
fl_ctx.set_prop(key=Constant.PARAM_KEY_HEADERS, value=headers, private=True, sticky=False)

@@ -279,8 +283,8 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext):
self.info(fl_ctx, "no encrypted result - ignore")
return

has_histograms = reply.get_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
if has_histograms:
horizontal = reply.get_header(Constant.HEADER_KEY_HORIZONTAL)
if horizontal:
self._process_after_all_gather_v_horizontal(fl_ctx)
else:
self._process_after_all_gather_v_vertical(fl_ctx)
@@ -360,11 +364,9 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
try:
self.tenseal_context = load_tenseal_context_from_workspace(self.tenseal_context_file, fl_ctx)
except Exception as ex:
self.info(fl_ctx,
f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}")
self.info(fl_ctx, f"Can't load tenseal context, horizontal secure XGBoost is not supported: {ex}")
self.tenseal_context = None
elif event_type == EventType.END_RUN:
self.tenseal_context = None
else:
super().handle_event(event_type, fl_ctx)
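For reference, the horizontal path added above packs the whole flattened histogram array into a single TenSEAL CKKS ciphertext and marks the request with the new HEADER_KEY_HORIZONTAL flag. A standalone sketch of that encryption step, using a locally created context and toy values (the real handler loads its context from the workspace and uses the Constant header keys):

```python
import tenseal as ts

# Toy CKKS context; in the handler this comes from load_tenseal_context_from_workspace().
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.global_scale = 2**40
context.generate_galois_keys()

# Flattened local histogram values (toy data standing in for the decoded histograms).
histograms = [0.5, 1.25, 3.0, 2.75]

# One ciphertext holds the entire array; this is what replaces PARAM_KEY_SEND_BUF.
vector = ts.ckks_vector(context, histograms)

# Headers carried alongside the buffer (literal values of the Constant keys above).
headers = {
    "xgb.encrypted_data": True,
    "xgb.horizontal": True,
    "xgb.original_buf_size": 1024,  # placeholder; the handler uses len(buffer)
}

print(vector.size())     # 4
print(vector.decrypt())  # approximately [0.5, 1.25, 3.0, 2.75] (CKKS is approximate)
```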

@@ -150,4 +150,3 @@ def to_float_array(result: FeatureAggregationResult) -> List[float]:
float_array.append(ProcessorDataConverter.int_to_float(h))

return float_array

29 changes: 19 additions & 10 deletions nvflare/app_opt/xgboost/histogram_based_v2/sec/server_handler.py
@@ -14,11 +14,10 @@
import os
import threading

from nvflare.app_opt.he import decomposers

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_opt.he import decomposers
from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
from nvflare.app_opt.xgboost.histogram_based_v2.sec.sec_handler import SecurityHandler

@@ -94,18 +93,22 @@ def _process_before_all_gather_v(self, fl_ctx: FLContext):
self.info(fl_ctx, "start - non-secure data")
return

has_histograms = request.get_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
split_mode = "horizontal" if has_histograms else "vertical"
horizontal = request.get_header(Constant.HEADER_KEY_HORIZONTAL)
split_mode = "horizontal" if horizontal else "vertical"
self.info(fl_ctx, f"start - {split_mode}")

fl_ctx.set_prop(key=Constant.HEADER_KEY_IN_AGGR, value=True, private=True, sticky=False)
fl_ctx.set_prop(key=Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS, value=has_histograms, private=True, sticky=False)
fl_ctx.set_prop(key=Constant.HEADER_KEY_HORIZONTAL, value=horizontal, private=True, sticky=False)

rank = fl_ctx.get_prop(Constant.PARAM_KEY_RANK)
send_buf = fl_ctx.get_prop(Constant.PARAM_KEY_SEND_BUF)
if send_buf:
if horizontal:
length = send_buf.size()
else:
length = len(send_buf)
# the send_buf contains encoded aggr result (str) or CKKS vector from this rank
self.info(fl_ctx, f"got encrypted aggr data: {len(send_buf)} bytes")
self.info(fl_ctx, f"got encrypted aggr data: {length} bytes")
with self.aggr_result_lock:
self.aggr_result_to_send = None
if not self.aggr_result_dict:
@@ -133,22 +136,28 @@ def _process_after_all_gather_v(self, fl_ctx: FLContext):

reply = fl_ctx.get_prop(Constant.PARAM_KEY_REPLY)
assert isinstance(reply, Shareable)
has_histograms = fl_ctx.get_prop(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS)
horizontal = fl_ctx.get_prop(Constant.HEADER_KEY_HORIZONTAL)
reply.set_header(Constant.HEADER_KEY_ENCRYPTED_DATA, True)
reply.set_header(Constant.HEADER_KEY_ENCRYPTED_HISTOGRAMS, has_histograms)
reply.set_header(Constant.HEADER_KEY_HORIZONTAL, horizontal)
with self.aggr_result_lock:
if not self.aggr_result_to_send:
if not self.aggr_result_dict:
return self._abort(f"Rank {rank}: no aggr result after AllGatherV!", fl_ctx)

if has_histograms:
if horizontal:
self.aggr_result_to_send = self._histogram_sum(fl_ctx)
else:
self.aggr_result_to_send = self.aggr_result_dict

# reset aggr_result_dict for next gather
self.aggr_result_dict = None
self.info(fl_ctx, f"aggr_result_to_send {len(self.aggr_result_to_send)}")

if horizontal:
length = self.aggr_result_to_send.size()
else:
length = len(self.aggr_result_to_send)

self.info(fl_ctx, f"aggr_result_to_send {length}")
fl_ctx.set_prop(key=Constant.PARAM_KEY_RCV_BUF, value=self.aggr_result_to_send, private=True, sticky=False)

def _histogram_sum(self, fl_ctx: FLContext):
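The server-side changes above switch to .size() for CKKS payloads (which have no len()) and route horizontal results through _histogram_sum. Its body is outside this diff, but since CKKS ciphertexts add directly, a plausible sketch of summing the per-rank encrypted histograms under that assumption is:

```python
import tenseal as ts

# Toy context with a secret key so the final decrypt works locally; in the real
# workflow only the data owners hold the secret key and the server adds blindly.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[60, 40, 40, 60],
)
context.global_scale = 2**40

# Stand-in for aggr_result_dict: rank -> encrypted histogram from that rank.
aggr_result_dict = {
    0: ts.ckks_vector(context, [1.0, 2.0, 3.0]),
    1: ts.ckks_vector(context, [0.5, 0.5, 0.5]),
}

# Sum the ciphertexts without decrypting anything.
total = None
for vec in aggr_result_dict.values():
    total = vec if total is None else total + vec

print(total.size())     # 3, the value logged via .size() in the handler above
print(total.decrypt())  # approximately [1.5, 2.5, 3.5], recoverable only by the key owner
```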
12 changes: 8 additions & 4 deletions nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
@@ -16,9 +16,12 @@

from nvflare.app_opt.xgboost.data_loader import XGBDataLoader

COL_SECURE = 2
ROW_SECURE = 3


class SecureDataLoader(XGBDataLoader):
def __init__(self, rank: int, folder: str):
def __init__(self, rank: int, folder: str, data_split_mode=COL_SECURE):
"""Reads CSV dataset and return XGB data matrix in vertical secure mode.
Args:
@@ -27,18 +30,19 @@ def __init__(self, rank: int, folder: str):
"""
self.rank = rank
self.folder = folder
self.data_split_mode = data_split_mode

def load_data(self, client_id: str):

train_path = f"{self.folder}/site-{self.rank + 1}/train.csv"
valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv"

if self.rank == 0:
if self.rank == 0 or self.data_split_mode == ROW_SECURE:
label = "&label_column=0"
else:
label = ""

train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=2)
valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=2)
train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=self.data_split_mode)
valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=self.data_split_mode)

return train_data, valid_data
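With the new data_split_mode parameter the same loader covers both vertical (COL_SECURE = 2) and horizontal (ROW_SECURE = 3) runs; in horizontal mode every site holds full rows, so each rank reads the label column. A hypothetical usage sketch (folder layout and client id are illustrative, and these split-mode values assume a secure XGBoost build that understands them):

```python
from nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader import (
    ROW_SECURE,
    SecureDataLoader,
)

# Horizontal (row-split) secure training: site-1 reads its own rows, labels included.
loader = SecureDataLoader(rank=0, folder="/data/xgb_dataset", data_split_mode=ROW_SECURE)
train_dmatrix, valid_dmatrix = loader.load_data(client_id="site-1")
```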
