move HE context part out of FL process to better accommodate the transition to real application
ZiyueXu77 committed Jan 10, 2024
1 parent 8737799 commit 8cb6d7d
Showing 7 changed files with 103 additions and 27 deletions.
7 changes: 6 additions & 1 deletion examples/advanced/kaplan-meier-he/README.md
@@ -28,7 +28,7 @@ The Flare Workflow Communicator API provides the functionality of customized mes
Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) do not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), mainly because the HE context information (specs and keys) must be provisioned before the federated job is initialized. For the same reason, it is not straightforward for users to try HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py).

With the Flare Workflow Communicator API, such "proof of concept" experiments become easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds:
1. Server generates and distributes the HE context to clients, and removes the private key on the server side. Again, this step is done by secure provisioning for real-life applications, but for the simulator, we use this step to distribute the HE context.
1. Server sends a simple start message without any payload.
2. Clients each determine their local maximum time-bin index and send it to the server, which aggregates the information by taking the maximum across all clients. The global maximum is then distributed back to the clients. This step is necessary because we want to standardize the histograms generated by all clients, so that they have exactly the same length and can be encrypted as vectors of the same size, which makes them addable.
3. Clients condense their local raw event lists into two histograms of the global length received, encrypt the histogram value vectors, and send them to the server. The server aggregates the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to the clients.
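The round structure above can be sketched in plain Python with the encryption step mocked out. The helper names below are illustrative, not the example's actual API; in the real pipeline the histogram vectors are encrypted with TenSEAL before the server-side addition:

```python
# Sketch of rounds 2 and 3; encryption is omitted so the flow is
# runnable without TenSEAL. Times are integer time-bin indices.

def local_max_idx(event_times):
    # Round 2 (client side): report the largest observed time bin.
    return max(event_times)

def aggregate_max(max_indices):
    # Round 2 (server side): global histogram length is the max across sites.
    return max(max_indices)

def standardized_histogram(event_times, global_max):
    # Round 3 (client side): same-length histograms are addable as vectors.
    hist = [0] * (global_max + 1)
    for t in event_times:
        hist[t] += 1
    return hist

def aggregate_histograms(histograms):
    # Round 3 (server side): element-wise addition; with HE this happens
    # directly on the encrypted vectors, without decryption.
    return [sum(vals) for vals in zip(*histograms)]

site_a, site_b = [1, 2, 2, 5], [3, 3, 7]
global_max = aggregate_max([local_max_idx(site_a), local_max_idx(site_b)])
hists = [standardized_histogram(s, global_max) for s in (site_a, site_b)]
global_hist = aggregate_histograms(hists)
```

Because every client pads to the same global length, the server never needs to inspect individual bin positions, which is what makes the blind encrypted addition possible.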

@@ -45,6 +45,11 @@ Then we run a 5-client federated job with the simulator, beginning with splitting and ge…
```commandline
python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data"
```
Then we prepare the HE context for the clients and the server. Note that this step is done by secure provisioning for real-life applications, but for the simulator, we use this step to distribute the HE context.
```commandline
python utils/prepare_he_context.py --out_path "/tmp/flare/he_context"
```
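The context files written by this script are base64-encoded serialized bytes. A minimal sketch of that read/write convention, mirroring the `write_data`/`read_data` helpers in this commit, with placeholder bytes standing in for a real serialized TenSEAL context:

```python
import base64
import os
import tempfile

def write_data(file_name: str, data: bytes):
    # Base64-encode before writing, as prepare_he_context.py does.
    with open(file_name, "wb") as f:
        f.write(base64.b64encode(data))

def read_data(file_name: str) -> bytes:
    # Decode on read, as the client script and server workflow do.
    with open(file_name, "rb") as f:
        return base64.b64decode(f.read())

# Placeholder bytes stand in for a serialized TenSEAL context.
payload = b"serialized-he-context"
path = os.path.join(tempfile.mkdtemp(), "he_context_client.txt")
write_data(path, payload)
assert read_data(path) == payload
```

Base64 keeps the `.txt` files text-safe while the underlying serialized context remains opaque binary data.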

And we can run the federated job:
```commandline
nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he
@@ -6,7 +6,7 @@
app_script = "kaplan_meier_train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = "--data_root /tmp/flare/dataset/km_data"
app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt"

# Client Computing Executors.
executors = [
@@ -13,6 +13,7 @@
wf_class_path = "kaplan_meier_wf.KM",
wf_args {
min_clients = 5
he_context_path = "/tmp/flare/he_context/he_context_server.txt"
}
}
}
@@ -13,6 +13,7 @@
# limitations under the License.

import argparse
import base64
import json
import os

@@ -29,6 +30,12 @@


# Client code
def read_data(file_name: str):
with open(file_name, "rb") as f:
data = f.read()
return base64.b64decode(data)


def details_save(kmf):
# Get the survival function at all observed time points
survival_function_at_all_times = kmf.survival_function_
@@ -71,6 +78,7 @@ def plot_and_save(kmf):
def main():
parser = argparse.ArgumentParser(description="KM analysis")
parser.add_argument("--data_root", type=str, help="Root path for data files")
parser.add_argument("--he_context_path", type=str, help="Path for the HE context file")
args = parser.parse_args()

flare.init()
@@ -84,6 +92,11 @@ def main():
event_local = data["event"]
time_local = data["time"]

# HE context
# In real-life applications, the HE context is prepared by secure provisioning
he_context_serial = read_data(args.he_context_path)
he_context = ts.context_from(he_context_serial)

while flare.is_running():
# receives global message from NVFlare
global_msg = flare.receive()
@@ -92,14 +105,7 @@

if curr_round == 1:
# First round:
# Get HE context from server
# Send max index back

# In real-life application, HE setup is done by secure provisioning
he_context_serial = global_msg.params["he_context"]
# bytes back to context object
he_context = ts.context_from(he_context_serial)

# Empty payload from server, send max index back
# Condense local data to histogram
event_table = survival_table_from_events(time_local, event_local)
hist_idx = event_table.index.values.astype(int)
@@ -108,7 +114,6 @@

# Send max to server
print(f"send max hist index for site = {flare.get_site_name()}")
# Send the results to server
model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL)
flare.send(model)

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import logging
from typing import Dict

@@ -31,39 +32,37 @@


class KM(WF):
def __init__(self, min_clients: int):
def __init__(self, min_clients: int, he_context_path: str):
super(KM, self).__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.min_clients = min_clients
self.he_context_path = he_context_path
self.num_rounds = 3

def run(self):
he_context, max_idx_results = self.distribute_he_context_collect_max_idx()
max_idx_results = self.start_fl_collect_max_idx()
global_res = self.aggr_max_idx(max_idx_results)
enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res)
hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results)
hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results)
_ = self.distribute_global_hist(hist_obs_global, hist_cen_global)

def distribute_he_context_collect_max_idx(self):
self.logger.info("send kaplan-meier analysis command to all sites with HE context \n")

context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)
context_serial = context.serialize(save_secret_key=True)
# drop private key for server
context.make_context_public()
# payload data always needs to be wrapped into an FLModel
model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL)
def read_data(self, file_name: str):
with open(file_name, "rb") as f:
data = f.read()
return base64.b64decode(data)

def start_fl_collect_max_idx(self):
self.logger.info("send initial message to all sites to start FL \n")
msg_payload = {
MIN_RESPONSES: self.min_clients,
CURRENT_ROUND: 1,
NUM_ROUNDS: self.num_rounds,
START_ROUND: 1,
DATA: model,
DATA: {},
}

results = self.flare_comm.broadcast_and_wait(msg_payload)
return context, results
return results

def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]):
self.logger.info("aggregate max histogram index \n")
@@ -99,9 +98,13 @@ def distribute_max_idx_collect_enc_stats(self, result: int):
results = self.flare_comm.broadcast_and_wait(msg_payload)
return results

def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]):
def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]):
self.logger.info("aggregate histogram within HE \n")

# Load HE context
he_context_serial = self.read_data(self.he_context_path)
he_context = ts.context_from(he_context_serial)

if not sag_result:
raise RuntimeError("input is None or empty")

2 changes: 1 addition & 1 deletion examples/advanced/kaplan-meier-he/utils/prepare_data.py
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
62 changes: 62 additions & 0 deletions examples/advanced/kaplan-meier-he/utils/prepare_he_context.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import base64
import os

import tenseal as ts


def data_split_args_parser():
parser = argparse.ArgumentParser(description="Generate HE context")
parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV")
parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096")
parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server")
return parser


def write_data(file_name: str, data: bytes):
data = base64.b64encode(data)
with open(file_name, "wb") as f:
f.write(data)


def main():
parser = data_split_args_parser()
args = parser.parse_args()
if args.scheme == "BFV":
scheme = ts.SCHEME_TYPE.BFV
# Generate HE context
context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193)
elif args.scheme == "CKKS":
scheme = ts.SCHEME_TYPE.CKKS
# Generate HE context, CKKS does not need plain_modulus
context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree)
else:
raise ValueError("HE scheme not supported")

# Save HE context to file for client
if not os.path.exists(args.out_path):
os.makedirs(args.out_path, exist_ok=True)
context_serial = context.serialize(save_secret_key=True)
write_data(os.path.join(args.out_path, "he_context_client.txt"), context_serial)

# Save HE context to file for server
context_serial = context.serialize(save_secret_key=False)
write_data(os.path.join(args.out_path, "he_context_server.txt"), context_serial)


if __name__ == "__main__":
main()
