diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 45f84a43e9..58cec6ebe2 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -28,7 +28,7 @@ The Flare Workflow Communicator API provides the functionality of customized mes Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py). With the Flare Workflow Communicator API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: -1. Server generate and distribute the HE context to clients, and remove the private key on server side. Again, this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +1. Server sends a simple start message without any payload. 2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. 3. 
Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. @@ -45,6 +45,11 @@ Then we run a 5-client federated job with simulator, begin with splitting and ge ```commandline python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data" ``` +Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +```commandline +python utils/prepare_he_context.py --out_path "/tmp/flare/he_context" +``` + And we can run the federated job: ```commandline nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf index bb0ca87a8e..0704590617 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf @@ -6,7 +6,7 @@ app_script = "kaplan_meier_train.py" # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. - app_config = "--data_root /tmp/flare/dataset/km_data" + app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" # Client Computing Executors. 
executors = [ diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf index de850e47bf..2ade684715 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf @@ -13,6 +13,7 @@ wf_class_path = "kaplan_meier_wf.KM", wf_args { min_clients = 5 + he_context_path = "/tmp/flare/he_context/he_context_server.txt" } } } diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py index a01e0a0d2f..401c26aaf1 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py @@ -13,6 +13,7 @@ # limitations under the License. 
import argparse +import base64 import json import os @@ -29,6 +30,12 @@ # Client code +def read_data(file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + + def details_save(kmf): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ @@ -71,6 +78,7 @@ def plot_and_save(kmf): def main(): parser = argparse.ArgumentParser(description="KM analysis") parser.add_argument("--data_root", type=str, help="Root path for data files") + parser.add_argument("--he_context_path", type=str, help="Path for the HE context file") args = parser.parse_args() flare.init() @@ -84,6 +92,11 @@ def main(): event_local = data["event"] time_local = data["time"] + # HE context + # In real-life application, HE context is prepared by secure provisioning + he_context_serial = read_data(args.he_context_path) + he_context = ts.context_from(he_context_serial) + while flare.is_running(): # receives global message from NVFlare global_msg = flare.receive() @@ -92,14 +105,7 @@ def main(): if curr_round == 1: # First round: - # Get HE context from server - # Send max index back - - # In real-life application, HE setup is done by secure provisioning - he_context_serial = global_msg.params["he_context"] - # bytes back to context object - he_context = ts.context_from(he_context_serial) - + # Empty payload from server, send max index back # Condense local data to histogram event_table = survival_table_from_events(time_local, event_local) hist_idx = event_table.index.values.astype(int) @@ -108,7 +114,6 @@ def main(): # Send max to server print(f"send max hist index for site = {flare.get_site_name()}") - # Send the results to server model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) flare.send(model) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py 
b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index 9fc2575652..1ec5e372e3 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging from typing import Dict @@ -31,39 +32,37 @@ class KM(WF): - def __init__(self, min_clients: int): + def __init__(self, min_clients: int, he_context_path: str): super(KM, self).__init__() self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients + self.he_context_path = he_context_path self.num_rounds = 3 def run(self): - he_context, max_idx_results = self.distribute_he_context_collect_max_idx() + max_idx_results = self.start_fl_collect_max_idx() global_res = self.aggr_max_idx(max_idx_results) enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) - hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results) + hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results) _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) - def distribute_he_context_collect_max_idx(self): - self.logger.info("send kaplan-meier analysis command to all sites with HE context \n") - - context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) - context_serial = context.serialize(save_secret_key=True) - # drop private key for server - context.make_context_public() - # payload data always needs to be wrapped into an FLModel - model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL) + def read_data(self, file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + def start_fl_collect_max_idx(self): + self.logger.info("send initial message to all sites 
to start FL \n") msg_payload = { MIN_RESPONSES: self.min_clients, CURRENT_ROUND: 1, NUM_ROUNDS: self.num_rounds, START_ROUND: 1, - DATA: model, + DATA: {}, } results = self.flare_comm.broadcast_and_wait(msg_payload) - return context, results + return results def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): self.logger.info("aggregate max histogram index \n") @@ -99,9 +98,13 @@ def distribute_max_idx_collect_enc_stats(self, result: int): results = self.flare_comm.broadcast_and_wait(msg_payload) return results - def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]): + def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): self.logger.info("aggregate histogram within HE \n") + # Load HE context + he_context_serial = self.read_data(self.he_context_path) + he_context = ts.context_from(he_context_serial) + if not sag_result: raise RuntimeError("input is None or empty") diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_data.py b/examples/advanced/kaplan-meier-he/utils/prepare_data.py index 951a93213c..66684a1b4b 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_data.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_data.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py new file mode 100644 index 0000000000..ceedf4c9a4 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import base64 +import os + +import tenseal as ts + + +def data_split_args_parser(): + parser = argparse.ArgumentParser(description="Generate HE context") + parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV") + parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096") + parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server") + return parser + + +def write_data(file_name: str, data: bytes): + data = base64.b64encode(data) + with open(file_name, "wb") as f: + f.write(data) + + +def main(): + parser = data_split_args_parser() + args = parser.parse_args() + if args.scheme == "BFV": + scheme = ts.SCHEME_TYPE.BFV + # Generate HE context + context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193) + elif args.scheme == "CKKS": + scheme = ts.SCHEME_TYPE.CKKS + # Generate HE context, CKKS does not need plain_modulus + context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree) + else: + raise ValueError("HE scheme not supported") + + # Save HE context to file for client + if not os.path.exists(args.out_path): + os.makedirs(args.out_path, exist_ok=True) + context_serial = context.serialize(save_secret_key=True) + write_data(os.path.join(args.out_path, "he_context_client.txt"), context_serial) + + # Save HE context to file for server + context_serial = context.serialize(save_secret_key=False) + write_data(os.path.join(args.out_path, 
"he_context_server.txt"), context_serial) + + +if __name__ == "__main__": + main()