From 6fb7a5013713d03375bbb06e982e6b4ed06ead04 Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Thu, 4 Jan 2024 20:23:49 -0500 Subject: [PATCH 01/20] add example for mulitparty kaplan meier analysis with HE --- .../baseline_kaplan_meier_multi_party.py | 208 ++++++++++++++++++ .../km_he/app/config/config_fed_client.conf | 116 ++++++++++ .../km_he/app/config/config_fed_server.conf | 23 ++ .../jobs/km_he/app/custom/kaplan_meier_wf.py | 158 +++++++++++++ .../jobs/km_he/app/custom/km_train.py | 206 +++++++++++++++++ .../kaplan-meier-he/jobs/km_he/meta.conf | 7 + .../advanced/kaplan-meier-he/requirements.txt | 12 + 7 files changed, 730 insertions(+) create mode 100644 examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py create mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf create mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf create mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py create mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py create mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf create mode 100644 examples/advanced/kaplan-meier-he/requirements.txt diff --git a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py b/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py new file mode 100644 index 0000000000..acf1e93c8a --- /dev/null +++ b/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy + +import matplotlib.pyplot as plt +import numpy as np +import tenseal as ts +from lifelines import KaplanMeierFitter +from lifelines.utils import survival_table_from_events +from sksurv.datasets import load_veterans_lung_cancer + + +def args_parser(): + parser = argparse.ArgumentParser(description="Kaplan Meier Survival Analysis") + parser.add_argument("--num_of_clients", type=int, default=5, help="number of clients") + parser.add_argument("--he", action="store_true", help="use homomorphic encryption") + parser.add_argument( + "--output_curve_path", + type=str, + default="./km_curve_multi_party.png", + help="save path for the output curve", + ) + return parser + + +def prepare_data(num_of_clients: int, bin_days: int = 7): + # Load data + data_x, data_y = load_veterans_lung_cancer() + # Get total data count + total_data_num = data_x.shape[0] + print(f"Total data count: {total_data_num}") + # Get event and time + event = data_y["Status"] + time = data_y["Survival_in_days"] + # Categorize data to a bin, default is a week (7 days) + time = np.ceil(time / bin_days).astype(int) + # Shuffle data + idx = np.random.permutation(total_data_num) + # Split data to clients + event_clients = {} + time_clients = {} + for i in range(num_of_clients): + start = int(i * total_data_num / num_of_clients) + end = int((i + 1) * total_data_num / num_of_clients) + event_i = event[idx[start:end]] + time_i = time[idx[start:end]] + event_clients[i] = event_i + time_clients[i] = time_i + return event, time, event_clients, time_clients + + +def main(): + parser = args_parser() + args = parser.parse_args() + + # Set parameters + num_of_clients = args.num_of_clients + he = args.he + output_curve_path = args.output_curve_path + + # Generate data + event, time, event_clients, time_clients = prepare_data(num_of_clients) + + # Setup Plot + plt.figure() + if he: + total_subplot = 3 + else: + total_subplot = 2 + + # Setup tenseal context + # using BFV scheme since observations are integers + if he: + context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) + + # Fit and plot Kaplan Meier curve with lifelines + kmf = KaplanMeierFitter() + kmf.fit(time, event) + # Plot the Kaplan-Meier survival curve + plt.subplot(1, total_subplot, 1) + plt.title("Centralized") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + + # Distributed + # Stage 1 local: collect info and set histogram dict + event_table = {} + max_week_idx = [] + for client in range(num_of_clients): + # condense date to histogram + event_table[client] = survival_table_from_events(time_clients[client], event_clients[client]) + week_idx = event_table[client].index.values.astype(int) + # get the max week index + max_week_idx.append(max(week_idx)) + # Stage 1 global: get global histogram dict + hist_obs_global = {} + hist_sen_global = {} + # actual week as index, so plus 1 + max_week = max(max_week_idx) + 1 + for week in range(max_week): + hist_obs_global[week] = 0 + hist_sen_global[week] = 0 + if he: + # encrypt with tenseal + hist_obs_global_he = ts.bfv_vector(context, list(hist_obs_global.values())) + hist_sen_global_he = ts.bfv_vector(context, list(hist_sen_global.values())) + # Stage 2 local: convert local table to uniform histogram + hist_obs_local = {} + hist_sen_local = {} + hist_obs_local_he = {} + hist_sen_local_he = {} + for client in range(num_of_clients): + hist_obs_local[client] = copy.deepcopy(hist_obs_global) + hist_sen_local[client] = copy.deepcopy(hist_sen_global) + # assign values + week_idx = event_table[client].index.values.astype(int) + observed = event_table[client]["observed"].to_numpy() + sensored = event_table[client]["censored"].to_numpy() + for i in range(len(week_idx)): + hist_obs_local[client][week_idx[i]] = observed[i] + hist_sen_local[client][week_idx[i]] = sensored[i] + if he: + # encrypt with tenseal using BFV scheme since observations are integers + hist_obs_local_he[client] = ts.bfv_vector(context, list(hist_obs_local[client].values())) + hist_sen_local_he[client] = ts.bfv_vector(context, list(hist_sen_local[client].values())) + # Stage 2 global: sum up local histogram + for client in range(num_of_clients): + for week in range(max_week): + hist_obs_global[week] += hist_obs_local[client][week] + hist_sen_global[week] += hist_sen_local[client][week] + if he: + hist_obs_global_he += hist_obs_local_he[client] + hist_sen_global_he += hist_sen_local_he[client] + + # Stage 3 local: convert histogram to event list and fit K-M curve + # unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_week): + for j in range(hist_obs_global[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_sen_global[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + # Fit the survival data with lifelines + kmf.fit(time_unfold, event_unfold) + # Plot the Kaplan-Meier survival curve + plt.subplot(1, total_subplot, 2) + plt.title("Federated") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + + if he: + # decrypt with tenseal + hist_obs_global_he = hist_obs_global_he.decrypt() + hist_sen_global_he = hist_sen_global_he.decrypt() + # unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_week): + for j in range(hist_obs_global_he[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_sen_global_he[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + # Fit the survival data with lifelines + kmf.fit(time_unfold, event_unfold) + # Plot the Kaplan-Meier survival curve + plt.subplot(1, total_subplot, 3) + plt.title("Federated HE") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + + # Save curve + plt.tight_layout() + plt.savefig(output_curve_path) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf new file mode 100644 index 0000000000..9de6ad8d7c --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf @@ -0,0 +1,116 @@ +{ + # version of the configuration + format_version = 2 + + # This is the application script which will be invoked. Client can replace this script with user's own training script. + app_script = "km_train.py" + + # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. + app_config = "" + + # Client Computing Executors. + executors = [ + { + # tasks the executors are defined to handle + tasks = ["train"] + + # This particular executor + executor { + + # This is an executor for Client API. The underline data exchange is using Pipe. + path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor" + + args { + # launcher_id is used to locate the Launcher object in "components" + launcher_id = "launcher" + + # pipe_id is used to locate the Pipe object in "components" + pipe_id = "pipe" + + # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. + # Please refer to the class docstring for all available arguments + heartbeat_timeout = 60 + + # format of the exchange parameters + params_exchange_format = "raw" + + # if the transfer_type is FULL, then it will be sent directly + # if the transfer_type is DIFF, then we will calculate the + # difference VS received parameters and send the difference + params_transfer_type = "FULL" + + # if train_with_evaluation is true, the executor will expect + # the custom code need to send back both the trained parameters and the evaluation metric + # otherwise only trained parameters are expected + train_with_evaluation = false + } + } + } + ], + + # this defined an array of task data filters. If provided, it will control the data from server controller to client executor + task_data_filters = [] + + # this defined an array of task result filters. If provided, it will control the result from client executor to server controller + task_result_filters = [] + + components = [ + { + # component id is "launcher" + id = "launcher" + + # the class path of this component + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + + args { + # the launcher will invoke the script + script = "python3 custom/{app_script} {app_config} " + # if launch_once is true, the SubprocessLauncher will launch once for the whole job + # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + }, + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + # how fast should it read from the peer + read_interval = 0.1 + } + }, + { + # we use this component so the client api `flare.init()` can get required information + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = ["metric_relay"] + } + } + ] +} diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf new file mode 100644 index 0000000000..618bb4b0b5 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf @@ -0,0 +1,23 @@ +{ + # version of the configuration + format_version = 2 + task_data_filters =[] + task_result_filters = [] + + workflows = [ + { + id = "km" + path = "nvflare.app_common.workflows.wf_controller.WFController" + args { + task_name = "train" + wf_class_path = "kaplan_meier_wf.KM", + wf_args { + min_clients = 2 + } + } + } + ] + + components = [] + +} diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py new file mode 100644 index 0000000000..b775afaaa7 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict + +import tenseal as ts + +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import ( + CURRENT_ROUND, + DATA, + MIN_RESPONSES, + NUM_ROUNDS, + START_ROUND, +) +from nvflare.app_common.workflows.wf_comm.wf_spec import WF + +# Controller Workflow + + +class KM(WF): + def __init__(self, min_clients: int): + super(KM, self).__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.min_clients = min_clients + self.num_rounds = 3 + + def run(self): + he_context, max_idx_results = self.distribute_he_context_collect_max_idx() + global_res = self.aggr_max_idx(max_idx_results) + enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) + hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results) + _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) + + def distribute_he_context_collect_max_idx(self): + self.logger.info("send kaplan-meier analysis command to all sites with HE context \n") + + context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) + context_serial = context.serialize(save_secret_key=True) + # drop private key for server + context.make_context_public() + # payload data always needs to be wrapped into an FLModel + model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 1, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return context, results + + def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate max histogram index \n") + + if not sag_result: + raise RuntimeError("input is None or empty") + + task_name, task_result = next(iter(sag_result.items())) + + if not task_result: + raise RuntimeError("task_result None or empty ") + + max_idx_global = [] + for site, fl_model in task_result.items(): + max_idx = fl_model.params["max_idx"] + print(max_idx) + max_idx_global.append(max_idx) + # actual time point as index, so plus 1 for storage + return max(max_idx_global) + 1 + + def distribute_max_idx_collect_enc_stats(self, result: int): + self.logger.info("send global max_index to all sites \n") + + model = FLModel(params={"max_idx_global": result}, params_type=ParamsType.FULL) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 2, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return results + + def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate histogram within HE \n") + + if not sag_result: + raise RuntimeError("input is None or empty") + + task_name, task_result = next(iter(sag_result.items())) + + if not task_result: + raise RuntimeError("task_result None or empty ") + + hist_obs_global = None + hist_cen_global = None + for site, fl_model in task_result.items(): + hist_obs_he_serial = fl_model.params["hist_obs"] + hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) + hist_cen_he_serial = fl_model.params["hist_cen"] + hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) + + if not hist_obs_global: + print(f"assign global hist with result from {site}") + hist_obs_global = hist_obs_he + else: + print(f"add to global hist with result from {site}") + hist_obs_global += hist_obs_he + + if not hist_cen_global: + print(f"assign global hist with result from {site}") + hist_cen_global = hist_cen_he + else: + print(f"add to global hist with result from {site}") + hist_cen_global += hist_cen_he + + # return the two accumulated vectors, serialized for transmission + hist_obs_global_serial = hist_obs_global.serialize() + hist_cen_global_serial = hist_cen_global.serialize() + return hist_obs_global_serial, hist_cen_global_serial + + def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): + self.logger.info("send global accumulated histograms within HE to all sites \n") + + model = FLModel( + params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, + params_type=ParamsType.FULL, + ) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 3, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return results diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py new file mode 100644 index 0000000000..add014386f --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py @@ -0,0 +1,206 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import tenseal as ts +from lifelines import KaplanMeierFitter +from lifelines.utils import survival_table_from_events +from sksurv.datasets import load_veterans_lung_cancer + +# (1) import nvflare client API +import nvflare.client as flare +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType + +# Client training code + +np.random.seed(77) + + +def prepare_data(num_of_clients: int = 2, bin_days: int = 7): + # Load data + data_x, data_y = load_veterans_lung_cancer() + # Get total data count + total_data_num = data_x.shape[0] + print(f"Total data count: {total_data_num}") + # Get event and time + event = data_y["Status"] + time = data_y["Survival_in_days"] + # Categorize data to a bin, default is a week (7 days) + time = np.ceil(time / bin_days).astype(int) + # Shuffle data + idx = np.random.permutation(total_data_num) + # Split data to clients + event_clients = {} + time_clients = {} + for i in range(num_of_clients): + start = int(i * total_data_num / num_of_clients) + end = int((i + 1) * total_data_num / num_of_clients) + event_i = event[idx[start:end]] + time_i = time[idx[start:end]] + event_clients["site-" + str(i + 1)] = event_i + time_clients["site-" + str(i + 1)] = time_i + return event_clients, time_clients + + +def save(result: dict): + file_path = os.path.join(os.getcwd(), "km_global.json") + print(f"save the result to {file_path} \n") + with open(file_path, "w") as json_file: + json.dump(result, json_file, indent=4) + + +def main(): + flare.init() + + site_name = flare.get_site_name() + print(f"Kaplan-meier analysis for {site_name}") + + # get local data + event_clients, time_clients = prepare_data() + event_local = event_clients[site_name] + time_local = time_clients[site_name] + + while flare.is_running(): + # receives global message from NVFlare + global_msg = flare.receive() + curr_round = global_msg.current_round + print(f"current_round={curr_round}") + + if curr_round == 1: + # First round: + # Get HE context from server + # Send max index back + + # In real-life application, HE setup is done by secure provisioning + he_context_serial = global_msg.params["he_context"] + # bytes back to context object + he_context = ts.context_from(he_context_serial) + + # Condense local data to histogram + event_table = survival_table_from_events(time_local, event_local) + hist_idx = event_table.index.values.astype(int) + # Get the max index to be synced globally + max_hist_idx = max(hist_idx) + + # Send max to server + print(f"send max hist index for site = {flare.get_site_name()}") + # Send the results to server + model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) + flare.send(model) + + elif curr_round == 2: + # Second round, get global max index + # Organize local histogram and encrypt + max_idx_global = global_msg.params["max_idx_global"] + print("Global Max Idx") + print(max_idx_global) + # Convert local table to uniform histogram + hist_obs = {} + hist_cen = {} + for idx in range(max_idx_global): + hist_obs[idx] = 0 + hist_cen[idx] = 0 + # assign values + idx = event_table.index.values.astype(int) + observed = event_table["observed"].to_numpy() + censored = event_table["censored"].to_numpy() + for i in range(len(idx)): + hist_obs[idx[i]] = observed[i] + hist_cen[idx[i]] = censored[i] + # Encrypt with tenseal using BFV scheme since observations are integers + hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Serialize for transmission + hist_obs_he_serial = hist_obs_he.serialize() + hist_cen_he_serial = hist_cen_he.serialize() + # Send encrypted histograms to server + response = FLModel( + params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL + ) + flare.send(response) + + elif curr_round == 3: + # Get global histograms + hist_obs_global_serial = global_msg.params["hist_obs_global"] + hist_cen_global_serial = global_msg.params["hist_cen_global"] + # Deserialize + hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + # Decrypt + hist_obs_global = hist_obs_global.decrypt() + hist_cen_global = hist_cen_global.decrypt() + # Unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_idx_global): + for j in range(hist_obs_global[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_cen_global[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + + # Perform Kaplan-Meier analysis on global aggregated information + # Create a Kaplan-Meier estimator + kmf = KaplanMeierFitter() + + # Fit the model + kmf.fit(durations=time_unfold, event_observed=event_unfold) + + # Plot and save KM curve + plt.figure() + plt.title("Federated HE") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + plt.tight_layout() + plt.savefig(os.path.join(os.getcwd(), "km_curve.png")) + + # Save global result to a json file + # Get the survival function at all observed time points + survival_function_at_all_times = kmf.survival_function_ + # Get the timeline (time points) + timeline = survival_function_at_all_times.index.values + # Get the KM estimate + km_estimate = survival_function_at_all_times["KM_estimate"].values + # Get the event count at each time point + event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events + # Get the survival rate at each time point (using the 1st column of the survival function) + survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values + # Return the results + results = { + "timeline": timeline.tolist(), + "km_estimate": km_estimate.tolist(), + "event_count": event_count.tolist(), + "survival_rate": survival_rate.tolist(), + } + save(results) + + # Send a simple response to server + response = FLModel(params={}, params_type=ParamsType.FULL) + flare.send(response) + + print(f"finish send for {site_name}, complete") + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf new file mode 100644 index 0000000000..5c81903a41 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf @@ -0,0 +1,7 @@ +{ + name = "fl_km" + deploy_map { + app = ["@ALL"] + } + min_clients = 2 +} diff --git a/examples/advanced/kaplan-meier-he/requirements.txt b/examples/advanced/kaplan-meier-he/requirements.txt new file mode 100644 index 0000000000..b95c29c97d --- /dev/null +++ b/examples/advanced/kaplan-meier-he/requirements.txt @@ -0,0 +1,12 @@ +lifelines + + + +git clone --recursive https://github.com/OpenMined/TenSEAL.git +cd TenSEAL/ +git submodule init +git submodule update +pip install . + + +scikit-survival \ No newline at end of file From e54e6b50e57ae2dac793894a2a76ce9012d94c49 Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Thu, 4 Jan 2024 20:41:07 -0500 Subject: [PATCH 02/20] update requirements --- examples/advanced/kaplan-meier-he/requirements.txt | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/requirements.txt b/examples/advanced/kaplan-meier-he/requirements.txt index b95c29c97d..6b15006556 100644 --- a/examples/advanced/kaplan-meier-he/requirements.txt +++ b/examples/advanced/kaplan-meier-he/requirements.txt @@ -1,12 +1,3 @@ lifelines - - - -git clone --recursive https://github.com/OpenMined/TenSEAL.git -cd TenSEAL/ -git submodule init -git submodule update -pip install . - - +tenseal scikit-survival \ No newline at end of file From 87ecf77194c2a95456d158a2931e95b0d80667cf Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 13:45:45 -0500 Subject: [PATCH 03/20] update baseline script, remove complex settings and keep basic only --- .../kaplan-meier-he/baseline_kaplan_meier.py | 72 ++++++ .../baseline_kaplan_meier_multi_party.py | 208 ------------------ .../jobs/km_he/app/custom/km_train.py | 2 +- 3 files changed, 73 insertions(+), 209 deletions(-) create mode 100644 examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py delete mode 100644 examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py diff --git a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py b/examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py new file mode 100644 index 0000000000..79e7b8052f --- /dev/null +++ b/examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import matplotlib.pyplot as plt +import numpy as np +from lifelines import KaplanMeierFitter +from sksurv.datasets import load_veterans_lung_cancer + + +def args_parser(): + parser = argparse.ArgumentParser(description="Kaplan Meier Survival Analysis Baseline") + parser.add_argument( + "--output_curve_path", + type=str, + default="./km_curve_baseline.png", + help="save path for the output curve", + ) + return parser + + +def prepare_data(bin_days: int = 7): + data_x, data_y = load_veterans_lung_cancer() + total_data_num = data_x.shape[0] + print(f"Total data count: {total_data_num}") + event = data_y["Status"] + time = data_y["Survival_in_days"] + # Categorize data to a bin, default is a week (7 days) + time = np.ceil(time / bin_days).astype(int) + return event, time + + +def main(): + parser = args_parser() + args = parser.parse_args() + + # Set parameters + output_curve_path = args.output_curve_path + + # Generate data + event, time = prepare_data() + + # Fit and plot Kaplan Meier curve with lifelines + kmf = KaplanMeierFitter() + # Fit the survival data + kmf.fit(time, event) + # Plot and save the Kaplan-Meier survival curve + plt.figure() + plt.title("Baseline") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + plt.tight_layout() + plt.savefig(output_curve_path) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py b/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py deleted file mode 100644 index acf1e93c8a..0000000000 --- a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier_multi_party.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import copy - -import matplotlib.pyplot as plt -import numpy as np -import tenseal as ts -from lifelines import KaplanMeierFitter -from lifelines.utils import survival_table_from_events -from sksurv.datasets import load_veterans_lung_cancer - - -def args_parser(): - parser = argparse.ArgumentParser(description="Kaplan Meier Survival Analysis") - parser.add_argument("--num_of_clients", type=int, default=5, help="number of clients") - parser.add_argument("--he", action="store_true", help="use homomorphic encryption") - parser.add_argument( - "--output_curve_path", - type=str, - default="./km_curve_multi_party.png", - help="save path for the output curve", - ) - return parser - - -def prepare_data(num_of_clients: int, bin_days: int = 7): - # Load data - data_x, data_y = load_veterans_lung_cancer() - # Get total data count - total_data_num = data_x.shape[0] - print(f"Total data count: {total_data_num}") - # Get event and time - event = data_y["Status"] - time = data_y["Survival_in_days"] - # Categorize data to a bin, default is a week (7 days) - time = np.ceil(time / bin_days).astype(int) - # Shuffle data - idx = np.random.permutation(total_data_num) - # Split data to clients - event_clients = {} - time_clients = {} - for i in range(num_of_clients): - start = int(i * total_data_num / num_of_clients) - end = int((i + 1) * total_data_num / num_of_clients) - event_i = event[idx[start:end]] - time_i = time[idx[start:end]] - event_clients[i] = event_i - time_clients[i] = time_i - return event, time, event_clients, time_clients - - -def main(): - parser = args_parser() - args = parser.parse_args() - - # Set parameters - num_of_clients = args.num_of_clients - he = args.he - output_curve_path = args.output_curve_path - - # Generate data - event, time, event_clients, time_clients = prepare_data(num_of_clients) - - # Setup Plot - plt.figure() - if he: - total_subplot = 3 - else: - total_subplot = 2 - - # Setup tenseal context - # using BFV scheme since observations are integers - if he: - context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) - - # Fit and plot Kaplan Meier curve with lifelines - kmf = KaplanMeierFitter() - kmf.fit(time, event) - # Plot the Kaplan-Meier survival curve - plt.subplot(1, total_subplot, 1) - plt.title("Centralized") - kmf.plot_survival_function() - plt.ylim(0, 1) - plt.ylabel("prob") - plt.xlabel("time") - plt.legend("", frameon=False) - - # Distributed - # Stage 1 local: collect info and set histogram dict - event_table = {} - max_week_idx = [] - for client in range(num_of_clients): - # condense date to histogram - event_table[client] = survival_table_from_events(time_clients[client], event_clients[client]) - week_idx = event_table[client].index.values.astype(int) - # get the max week index - max_week_idx.append(max(week_idx)) - # Stage 1 global: get global histogram dict - hist_obs_global = {} - hist_sen_global = {} - # actual week as index, so plus 1 - max_week = max(max_week_idx) + 1 - for week in range(max_week): - hist_obs_global[week] = 0 - hist_sen_global[week] = 0 - if he: - # encrypt with tenseal - hist_obs_global_he = ts.bfv_vector(context, list(hist_obs_global.values())) - hist_sen_global_he = ts.bfv_vector(context, list(hist_sen_global.values())) - # Stage 2 local: convert local table to uniform histogram - hist_obs_local = {} - hist_sen_local = {} - hist_obs_local_he = {} - hist_sen_local_he = {} - for client in range(num_of_clients): - hist_obs_local[client] = copy.deepcopy(hist_obs_global) - hist_sen_local[client] = copy.deepcopy(hist_sen_global) - # assign values - week_idx = event_table[client].index.values.astype(int) - observed = event_table[client]["observed"].to_numpy() - sensored = event_table[client]["censored"].to_numpy() - for i in range(len(week_idx)): - hist_obs_local[client][week_idx[i]] = observed[i] - hist_sen_local[client][week_idx[i]] = sensored[i] - if he: - # encrypt with tenseal using BFV scheme since observations are integers - hist_obs_local_he[client] = ts.bfv_vector(context, list(hist_obs_local[client].values())) - hist_sen_local_he[client] = ts.bfv_vector(context, list(hist_sen_local[client].values())) - # Stage 2 global: sum up local histogram - for client in range(num_of_clients): - for week in range(max_week): - hist_obs_global[week] += hist_obs_local[client][week] - hist_sen_global[week] += hist_sen_local[client][week] - if he: - hist_obs_global_he += hist_obs_local_he[client] - hist_sen_global_he += hist_sen_local_he[client] - - # Stage 3 local: convert histogram to event list and fit K-M curve - # unfold histogram to event list - time_unfold = [] - event_unfold = [] - for i in range(max_week): - for j in range(hist_obs_global[i]): - time_unfold.append(i) - event_unfold.append(True) - for k in range(hist_sen_global[i]): - time_unfold.append(i) - event_unfold.append(False) - time_unfold = np.array(time_unfold) - event_unfold = np.array(event_unfold) - # Fit the survival data with lifelines - kmf.fit(time_unfold, event_unfold) - # Plot the Kaplan-Meier survival curve - plt.subplot(1, total_subplot, 2) - plt.title("Federated") - kmf.plot_survival_function() - plt.ylim(0, 1) - plt.ylabel("prob") - plt.xlabel("time") - plt.legend("", frameon=False) - - if he: - # decrypt with tenseal - hist_obs_global_he = hist_obs_global_he.decrypt() - hist_sen_global_he = hist_sen_global_he.decrypt() - # unfold histogram to event list - time_unfold = [] - event_unfold = [] - for i in range(max_week): - for j in range(hist_obs_global_he[i]): - time_unfold.append(i) - event_unfold.append(True) - for k in range(hist_sen_global_he[i]): - time_unfold.append(i) - event_unfold.append(False) - time_unfold = np.array(time_unfold) - event_unfold = np.array(event_unfold) - # Fit the survival data with lifelines - kmf.fit(time_unfold, event_unfold) - # Plot the Kaplan-Meier survival curve - plt.subplot(1, total_subplot, 3) - plt.title("Federated HE") - kmf.plot_survival_function() - plt.ylim(0, 1) - plt.ylabel("prob") - plt.xlabel("time") - plt.legend("", frameon=False) - - # Save curve - plt.tight_layout() - plt.savefig(output_curve_path) - - -if __name__ == "__main__": - main() diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py index add014386f..1ef907c9f7 100644 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py @@ -164,7 +164,7 @@ def main(): # Fit the model kmf.fit(durations=time_unfold, event_observed=event_unfold) - # Plot and save KM curve + # Plot and save the Kaplan-Meier survival curve plt.figure() plt.title("Federated HE") kmf.plot_survival_function() From 6031902e3d389e2405c54646682e07e48d4d526a Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 15:04:27 -0500 Subject: [PATCH 04/20] add readme with details --- examples/advanced/kaplan-meier-he/README.md | 52 +++++ .../km_he/app/config/config_fed_client.conf | 116 ---------- .../km_he/app/config/config_fed_server.conf | 23 -- .../jobs/km_he/app/custom/kaplan_meier_wf.py | 158 -------------- .../jobs/km_he/app/custom/km_train.py | 206 ------------------ .../kaplan-meier-he/jobs/km_he/meta.conf | 7 - 6 files changed, 52 insertions(+), 510 deletions(-) create mode 100644 examples/advanced/kaplan-meier-he/README.md delete mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf delete mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf delete mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py delete mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py delete mode 100644 examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md new file mode 100644 index 0000000000..026b7f3569 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/README.md @@ -0,0 +1,52 @@ +# Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption + +This example illustrates two features: +* How to perform Kaplan-Meirer survival analysis in federated setting securely via Homomorphic Encryption (HE). +* How to use the Flare Workflow Communicator API to contract a workflow to facilitate HE under simulator mode. + +## Secure Multi-party Kaplan-Meier Analysis +Kaplan-Meier survival analysis is a one-shot (non-iterative) analysis performed on a list of events and their corresponding time. In this example, we use [lifelines](https://zenodo.org/records/10456828) to perform this analysis. + +Essentially, the estimator needs to get access to the event list, and under the setting of federated analysis, the aggregated event list from all participants. + +However, this poses a data security concern - by sharing the event list, the raw data can be exposed to external parties, which break the core value of federated analysis. + +Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing any raw information from a certain participant (at server end). This is achieved by two techniques: + +- Condense the raw event list to two histograms (one for observed events and the other for censored event) binned at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms. +- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE. + +With these two settings, the server will have no access to any knowledge regarding local submissions, and participants will only receive global aggregated histograms that will not contain distinguishable information regarding any individual participants (client number >= 3 - if only two participants, one can infer the other party's info by subtracting its own histograms). + +The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms. + + +## Simulated HE Analysis via FLARE Workflow Communicator API + +The Flare Workflow Communicator API provides the functionality of customized message payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme. + +Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py). + +With the Flare Workflow Communicator API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: +1. Server generate and distribute the HE context to clients, and remove the private key on server side. Again, this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. +3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. + +After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information. + +## Run the job +We first run a baseline analysis with full event information: +```commandline +python baseline_kaplan_meier.py +``` +By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory. + +Then we run the federated job with simulator +```commandline +nvflare simulator -w workspace_km_he -n 2 -t 2 jobs/kaplan-meier-he +``` +By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory. + +## Display Result + +By comparing the two curves, we can observe that the two are identical. \ No newline at end of file diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf deleted file mode 100644 index 9de6ad8d7c..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_client.conf +++ /dev/null @@ -1,116 +0,0 @@ -{ - # version of the configuration - format_version = 2 - - # This is the application script which will be invoked. Client can replace this script with user's own training script. - app_script = "km_train.py" - - # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. - app_config = "" - - # Client Computing Executors. - executors = [ - { - # tasks the executors are defined to handle - tasks = ["train"] - - # This particular executor - executor { - - # This is an executor for Client API. The underline data exchange is using Pipe. - path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor" - - args { - # launcher_id is used to locate the Launcher object in "components" - launcher_id = "launcher" - - # pipe_id is used to locate the Pipe object in "components" - pipe_id = "pipe" - - # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. - # Please refer to the class docstring for all available arguments - heartbeat_timeout = 60 - - # format of the exchange parameters - params_exchange_format = "raw" - - # if the transfer_type is FULL, then it will be sent directly - # if the transfer_type is DIFF, then we will calculate the - # difference VS received parameters and send the difference - params_transfer_type = "FULL" - - # if train_with_evaluation is true, the executor will expect - # the custom code need to send back both the trained parameters and the evaluation metric - # otherwise only trained parameters are expected - train_with_evaluation = false - } - } - } - ], - - # this defined an array of task data filters. If provided, it will control the data from server controller to client executor - task_data_filters = [] - - # this defined an array of task result filters. If provided, it will control the result from client executor to server controller - task_result_filters = [] - - components = [ - { - # component id is "launcher" - id = "launcher" - - # the class path of this component - path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" - - args { - # the launcher will invoke the script - script = "python3 custom/{app_script} {app_config} " - # if launch_once is true, the SubprocessLauncher will launch once for the whole job - # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server - launch_once = true - } - } - { - id = "pipe" - path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" - args { - mode = "PASSIVE" - site_name = "{SITE_NAME}" - token = "{JOB_ID}" - root_url = "{ROOT_URL}" - secure_mode = "{SECURE_MODE}" - workspace_dir = "{WORKSPACE}" - } - } - { - id = "metrics_pipe" - path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" - args { - mode = "PASSIVE" - site_name = "{SITE_NAME}" - token = "{JOB_ID}" - root_url = "{ROOT_URL}" - secure_mode = "{SECURE_MODE}" - workspace_dir = "{WORKSPACE}" - } - }, - { - id = "metric_relay" - path = "nvflare.app_common.widgets.metric_relay.MetricRelay" - args { - pipe_id = "metrics_pipe" - event_type = "fed.analytix_log_stats" - # how fast should it read from the peer - read_interval = 0.1 - } - }, - { - # we use this component so the client api `flare.init()` can get required information - id = "config_preparer" - path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" - args { - component_ids = ["metric_relay"] - } - } - ] -} diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf deleted file mode 100644 index 618bb4b0b5..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/app/config/config_fed_server.conf +++ /dev/null @@ -1,23 +0,0 @@ -{ - # version of the configuration - format_version = 2 - task_data_filters =[] - task_result_filters = [] - - workflows = [ - { - id = "km" - path = "nvflare.app_common.workflows.wf_controller.WFController" - args { - task_name = "train" - wf_class_path = "kaplan_meier_wf.KM", - wf_args { - min_clients = 2 - } - } - } - ] - - components = [] - -} diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py deleted file mode 100644 index b775afaaa7..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/kaplan_meier_wf.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Dict - -import tenseal as ts - -from nvflare.app_common.abstract.fl_model import FLModel, ParamsType -from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import ( - CURRENT_ROUND, - DATA, - MIN_RESPONSES, - NUM_ROUNDS, - START_ROUND, -) -from nvflare.app_common.workflows.wf_comm.wf_spec import WF - -# Controller Workflow - - -class KM(WF): - def __init__(self, min_clients: int): - super(KM, self).__init__() - self.logger = logging.getLogger(self.__class__.__name__) - self.min_clients = min_clients - self.num_rounds = 3 - - def run(self): - he_context, max_idx_results = self.distribute_he_context_collect_max_idx() - global_res = self.aggr_max_idx(max_idx_results) - enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) - hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results) - _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) - - def distribute_he_context_collect_max_idx(self): - self.logger.info("send kaplan-meier analysis command to all sites with HE context \n") - - context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) - context_serial = context.serialize(save_secret_key=True) - # drop private key for server - context.make_context_public() - # payload data always needs to be wrapped into an FLModel - model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL) - - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 1, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: model, - } - - results = self.flare_comm.broadcast_and_wait(msg_payload) - return context, results - - def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): - self.logger.info("aggregate max histogram index \n") - - if not sag_result: - raise RuntimeError("input is None or empty") - - task_name, task_result = next(iter(sag_result.items())) - - if not task_result: - raise RuntimeError("task_result None or empty ") - - max_idx_global = [] - for site, fl_model in task_result.items(): - max_idx = fl_model.params["max_idx"] - print(max_idx) - max_idx_global.append(max_idx) - # actual time point as index, so plus 1 for storage - return max(max_idx_global) + 1 - - def distribute_max_idx_collect_enc_stats(self, result: int): - self.logger.info("send global max_index to all sites \n") - - model = FLModel(params={"max_idx_global": result}, params_type=ParamsType.FULL) - - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 2, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: model, - } - - results = self.flare_comm.broadcast_and_wait(msg_payload) - return results - - def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]): - self.logger.info("aggregate histogram within HE \n") - - if not sag_result: - raise RuntimeError("input is None or empty") - - task_name, task_result = next(iter(sag_result.items())) - - if not task_result: - raise RuntimeError("task_result None or empty ") - - hist_obs_global = None - hist_cen_global = None - for site, fl_model in task_result.items(): - hist_obs_he_serial = fl_model.params["hist_obs"] - hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) - hist_cen_he_serial = fl_model.params["hist_cen"] - hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) - - if not hist_obs_global: - print(f"assign global hist with result from {site}") - hist_obs_global = hist_obs_he - else: - print(f"add to global hist with result from {site}") - hist_obs_global += hist_obs_he - - if not hist_cen_global: - print(f"assign global hist with result from {site}") - hist_cen_global = hist_cen_he - else: - print(f"add to global hist with result from {site}") - hist_cen_global += hist_cen_he - - # return the two accumulated vectors, serialized for transmission - hist_obs_global_serial = hist_obs_global.serialize() - hist_cen_global_serial = hist_cen_global.serialize() - return hist_obs_global_serial, hist_cen_global_serial - - def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): - self.logger.info("send global accumulated histograms within HE to all sites \n") - - model = FLModel( - params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, - params_type=ParamsType.FULL, - ) - - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 3, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: model, - } - - results = self.flare_comm.broadcast_and_wait(msg_payload) - return results diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py b/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py deleted file mode 100644 index 1ef907c9f7..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/app/custom/km_train.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os - -import matplotlib.pyplot as plt -import numpy as np -import tenseal as ts -from lifelines import KaplanMeierFitter -from lifelines.utils import survival_table_from_events -from sksurv.datasets import load_veterans_lung_cancer - -# (1) import nvflare client API -import nvflare.client as flare -from nvflare.app_common.abstract.fl_model import FLModel, ParamsType - -# Client training code - -np.random.seed(77) - - -def prepare_data(num_of_clients: int = 2, bin_days: int = 7): - # Load data - data_x, data_y = load_veterans_lung_cancer() - # Get total data count - total_data_num = data_x.shape[0] - print(f"Total data count: {total_data_num}") - # Get event and time - event = data_y["Status"] - time = data_y["Survival_in_days"] - # Categorize data to a bin, default is a week (7 days) - time = np.ceil(time / bin_days).astype(int) - # Shuffle data - idx = np.random.permutation(total_data_num) - # Split data to clients - event_clients = {} - time_clients = {} - for i in range(num_of_clients): - start = int(i * total_data_num / num_of_clients) - end = int((i + 1) * total_data_num / num_of_clients) - event_i = event[idx[start:end]] - time_i = time[idx[start:end]] - event_clients["site-" + str(i + 1)] = event_i - time_clients["site-" + str(i + 1)] = time_i - return event_clients, time_clients - - -def save(result: dict): - file_path = os.path.join(os.getcwd(), "km_global.json") - print(f"save the result to {file_path} \n") - with open(file_path, "w") as json_file: - json.dump(result, json_file, indent=4) - - -def main(): - flare.init() - - site_name = flare.get_site_name() - print(f"Kaplan-meier analysis for {site_name}") - - # get local data - event_clients, time_clients = prepare_data() - event_local = event_clients[site_name] - time_local = time_clients[site_name] - - while flare.is_running(): - # receives global message from NVFlare - global_msg = flare.receive() - curr_round = global_msg.current_round - print(f"current_round={curr_round}") - - if curr_round == 1: - # First round: - # Get HE context from server - # Send max index back - - # In real-life application, HE setup is done by secure provisioning - he_context_serial = global_msg.params["he_context"] - # bytes back to context object - he_context = ts.context_from(he_context_serial) - - # Condense local data to histogram - event_table = survival_table_from_events(time_local, event_local) - hist_idx = event_table.index.values.astype(int) - # Get the max index to be synced globally - max_hist_idx = max(hist_idx) - - # Send max to server - print(f"send max hist index for site = {flare.get_site_name()}") - # Send the results to server - model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) - flare.send(model) - - elif curr_round == 2: - # Second round, get global max index - # Organize local histogram and encrypt - max_idx_global = global_msg.params["max_idx_global"] - print("Global Max Idx") - print(max_idx_global) - # Convert local table to uniform histogram - hist_obs = {} - hist_cen = {} - for idx in range(max_idx_global): - hist_obs[idx] = 0 - hist_cen[idx] = 0 - # assign values - idx = event_table.index.values.astype(int) - observed = event_table["observed"].to_numpy() - censored = event_table["censored"].to_numpy() - for i in range(len(idx)): - hist_obs[idx[i]] = observed[i] - hist_cen[idx[i]] = censored[i] - # Encrypt with tenseal using BFV scheme since observations are integers - hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) - hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) - # Serialize for transmission - hist_obs_he_serial = hist_obs_he.serialize() - hist_cen_he_serial = hist_cen_he.serialize() - # Send encrypted histograms to server - response = FLModel( - params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL - ) - flare.send(response) - - elif curr_round == 3: - # Get global histograms - hist_obs_global_serial = global_msg.params["hist_obs_global"] - hist_cen_global_serial = global_msg.params["hist_cen_global"] - # Deserialize - hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) - hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) - # Decrypt - hist_obs_global = hist_obs_global.decrypt() - hist_cen_global = hist_cen_global.decrypt() - # Unfold histogram to event list - time_unfold = [] - event_unfold = [] - for i in range(max_idx_global): - for j in range(hist_obs_global[i]): - time_unfold.append(i) - event_unfold.append(True) - for k in range(hist_cen_global[i]): - time_unfold.append(i) - event_unfold.append(False) - time_unfold = np.array(time_unfold) - event_unfold = np.array(event_unfold) - - # Perform Kaplan-Meier analysis on global aggregated information - # Create a Kaplan-Meier estimator - kmf = KaplanMeierFitter() - - # Fit the model - kmf.fit(durations=time_unfold, event_observed=event_unfold) - - # Plot and save the Kaplan-Meier survival curve - plt.figure() - plt.title("Federated HE") - kmf.plot_survival_function() - plt.ylim(0, 1) - plt.ylabel("prob") - plt.xlabel("time") - plt.legend("", frameon=False) - plt.tight_layout() - plt.savefig(os.path.join(os.getcwd(), "km_curve.png")) - - # Save global result to a json file - # Get the survival function at all observed time points - survival_function_at_all_times = kmf.survival_function_ - # Get the timeline (time points) - timeline = survival_function_at_all_times.index.values - # Get the KM estimate - km_estimate = survival_function_at_all_times["KM_estimate"].values - # Get the event count at each time point - event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events - # Get the survival rate at each time point (using the 1st column of the survival function) - survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values - # Return the results - results = { - "timeline": timeline.tolist(), - "km_estimate": km_estimate.tolist(), - "event_count": event_count.tolist(), - "survival_rate": survival_rate.tolist(), - } - save(results) - - # Send a simple response to server - response = FLModel(params={}, params_type=ParamsType.FULL) - flare.send(response) - - print(f"finish send for {site_name}, complete") - - -if __name__ == "__main__": - main() diff --git a/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf b/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf deleted file mode 100644 index 5c81903a41..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/km_he/meta.conf +++ /dev/null @@ -1,7 +0,0 @@ -{ - name = "fl_km" - deploy_map { - app = ["@ALL"] - } - min_clients = 2 -} From fb950ed2e3c52a8cc0292bab521b47ee22e3301e Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 15:05:33 -0500 Subject: [PATCH 05/20] add readme with details --- .../app/config/config_fed_client.conf | 116 ++++++++++ .../app/config/config_fed_server.conf | 23 ++ .../app/custom/kaplan_meier_wf.py | 158 ++++++++++++++ .../kaplan-meier-he/app/custom/km_train.py | 206 ++++++++++++++++++ .../jobs/kaplan-meier-he/meta.conf | 7 + 5 files changed, 510 insertions(+) create mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf create mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf create mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py create mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py create mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf new file mode 100644 index 0000000000..9de6ad8d7c --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf @@ -0,0 +1,116 @@ +{ + # version of the configuration + format_version = 2 + + # This is the application script which will be invoked. Client can replace this script with user's own training script. + app_script = "km_train.py" + + # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. + app_config = "" + + # Client Computing Executors. + executors = [ + { + # tasks the executors are defined to handle + tasks = ["train"] + + # This particular executor + executor { + + # This is an executor for Client API. The underline data exchange is using Pipe. + path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor" + + args { + # launcher_id is used to locate the Launcher object in "components" + launcher_id = "launcher" + + # pipe_id is used to locate the Pipe object in "components" + pipe_id = "pipe" + + # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. + # Please refer to the class docstring for all available arguments + heartbeat_timeout = 60 + + # format of the exchange parameters + params_exchange_format = "raw" + + # if the transfer_type is FULL, then it will be sent directly + # if the transfer_type is DIFF, then we will calculate the + # difference VS received parameters and send the difference + params_transfer_type = "FULL" + + # if train_with_evaluation is true, the executor will expect + # the custom code need to send back both the trained parameters and the evaluation metric + # otherwise only trained parameters are expected + train_with_evaluation = false + } + } + } + ], + + # this defined an array of task data filters. If provided, it will control the data from server controller to client executor + task_data_filters = [] + + # this defined an array of task result filters. If provided, it will control the result from client executor to server controller + task_result_filters = [] + + components = [ + { + # component id is "launcher" + id = "launcher" + + # the class path of this component + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + + args { + # the launcher will invoke the script + script = "python3 custom/{app_script} {app_config} " + # if launch_once is true, the SubprocessLauncher will launch once for the whole job + # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + }, + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + # how fast should it read from the peer + read_interval = 0.1 + } + }, + { + # we use this component so the client api `flare.init()` can get required information + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = ["metric_relay"] + } + } + ] +} diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf new file mode 100644 index 0000000000..618bb4b0b5 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf @@ -0,0 +1,23 @@ +{ + # version of the configuration + format_version = 2 + task_data_filters =[] + task_result_filters = [] + + workflows = [ + { + id = "km" + path = "nvflare.app_common.workflows.wf_controller.WFController" + args { + task_name = "train" + wf_class_path = "kaplan_meier_wf.KM", + wf_args { + min_clients = 2 + } + } + } + ] + + components = [] + +} diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py new file mode 100644 index 0000000000..b775afaaa7 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Dict + +import tenseal as ts + +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import ( + CURRENT_ROUND, + DATA, + MIN_RESPONSES, + NUM_ROUNDS, + START_ROUND, +) +from nvflare.app_common.workflows.wf_comm.wf_spec import WF + +# Controller Workflow + + +class KM(WF): + def __init__(self, min_clients: int): + super(KM, self).__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.min_clients = min_clients + self.num_rounds = 3 + + def run(self): + he_context, max_idx_results = self.distribute_he_context_collect_max_idx() + global_res = self.aggr_max_idx(max_idx_results) + enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) + hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results) + _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) + + def distribute_he_context_collect_max_idx(self): + self.logger.info("send kaplan-meier analysis command to all sites with HE context \n") + + context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) + context_serial = context.serialize(save_secret_key=True) + # drop private key for server + context.make_context_public() + # payload data always needs to be wrapped into an FLModel + model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 1, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return context, results + + def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate max histogram index \n") + + if not sag_result: + raise RuntimeError("input is None or empty") + + task_name, task_result = next(iter(sag_result.items())) + + if not task_result: + raise RuntimeError("task_result None or empty ") + + max_idx_global = [] + for site, fl_model in task_result.items(): + max_idx = fl_model.params["max_idx"] + print(max_idx) + max_idx_global.append(max_idx) + # actual time point as index, so plus 1 for storage + return max(max_idx_global) + 1 + + def distribute_max_idx_collect_enc_stats(self, result: int): + self.logger.info("send global max_index to all sites \n") + + model = FLModel(params={"max_idx_global": result}, params_type=ParamsType.FULL) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 2, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return results + + def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate histogram within HE \n") + + if not sag_result: + raise RuntimeError("input is None or empty") + + task_name, task_result = next(iter(sag_result.items())) + + if not task_result: + raise RuntimeError("task_result None or empty ") + + hist_obs_global = None + hist_cen_global = None + for site, fl_model in task_result.items(): + hist_obs_he_serial = fl_model.params["hist_obs"] + hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) + hist_cen_he_serial = fl_model.params["hist_cen"] + hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) + + if not hist_obs_global: + print(f"assign global hist with result from {site}") + hist_obs_global = hist_obs_he + else: + print(f"add to global hist with result from {site}") + hist_obs_global += hist_obs_he + + if not hist_cen_global: + print(f"assign global hist with result from {site}") + hist_cen_global = hist_cen_he + else: + print(f"add to global hist with result from {site}") + hist_cen_global += hist_cen_he + + # return the two accumulated vectors, serialized for transmission + hist_obs_global_serial = hist_obs_global.serialize() + hist_cen_global_serial = hist_cen_global.serialize() + return hist_obs_global_serial, hist_cen_global_serial + + def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): + self.logger.info("send global accumulated histograms within HE to all sites \n") + + model = FLModel( + params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, + params_type=ParamsType.FULL, + ) + + msg_payload = { + MIN_RESPONSES: self.min_clients, + CURRENT_ROUND: 3, + NUM_ROUNDS: self.num_rounds, + START_ROUND: 1, + DATA: model, + } + + results = self.flare_comm.broadcast_and_wait(msg_payload) + return results diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py new file mode 100644 index 0000000000..30742eba2c --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py @@ -0,0 +1,206 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import tenseal as ts +from lifelines import KaplanMeierFitter +from lifelines.utils import survival_table_from_events +from sksurv.datasets import load_veterans_lung_cancer + +# (1) import nvflare client API +import nvflare.client as flare +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType + +# Client training code + +np.random.seed(77) + + +def prepare_data(num_of_clients: int = 2, bin_days: int = 7): + # Load data + data_x, data_y = load_veterans_lung_cancer() + # Get total data count + total_data_num = data_x.shape[0] + print(f"Total data count: {total_data_num}") + # Get event and time + event = data_y["Status"] + time = data_y["Survival_in_days"] + # Categorize data to a bin, default is a week (7 days) + time = np.ceil(time / bin_days).astype(int) + # Shuffle data + idx = np.random.permutation(total_data_num) + # Split data to clients + event_clients = {} + time_clients = {} + for i in range(num_of_clients): + start = int(i * total_data_num / num_of_clients) + end = int((i + 1) * total_data_num / num_of_clients) + event_i = event[idx[start:end]] + time_i = time[idx[start:end]] + event_clients["site-" + str(i + 1)] = event_i + time_clients["site-" + str(i + 1)] = time_i + return event_clients, time_clients + + +def save(result: dict): + file_path = os.path.join(os.getcwd(), "km_global.json") + print(f"save the result to {file_path} \n") + with open(file_path, "w") as json_file: + json.dump(result, json_file, indent=4) + + +def main(): + flare.init() + + site_name = flare.get_site_name() + print(f"Kaplan-meier analysis for {site_name}") + + # get local data + event_clients, time_clients = prepare_data() + event_local = event_clients[site_name] + time_local = time_clients[site_name] + + while flare.is_running(): + # receives global message from NVFlare + global_msg = flare.receive() + curr_round = global_msg.current_round + print(f"current_round={curr_round}") + + if curr_round == 1: + # First round: + # Get HE context from server + # Send max index back + + # In real-life application, HE setup is done by secure provisioning + he_context_serial = global_msg.params["he_context"] + # bytes back to context object + he_context = ts.context_from(he_context_serial) + + # Condense local data to histogram + event_table = survival_table_from_events(time_local, event_local) + hist_idx = event_table.index.values.astype(int) + # Get the max index to be synced globally + max_hist_idx = max(hist_idx) + + # Send max to server + print(f"send max hist index for site = {flare.get_site_name()}") + # Send the results to server + model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) + flare.send(model) + + elif curr_round == 2: + # Second round, get global max index + # Organize local histogram and encrypt + max_idx_global = global_msg.params["max_idx_global"] + print("Global Max Idx") + print(max_idx_global) + # Convert local table to uniform histogram + hist_obs = {} + hist_cen = {} + for idx in range(max_idx_global): + hist_obs[idx] = 0 + hist_cen[idx] = 0 + # assign values + idx = event_table.index.values.astype(int) + observed = event_table["observed"].to_numpy() + censored = event_table["censored"].to_numpy() + for i in range(len(idx)): + hist_obs[idx[i]] = observed[i] + hist_cen[idx[i]] = censored[i] + # Encrypt with tenseal using BFV scheme since observations are integers + hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Serialize for transmission + hist_obs_he_serial = hist_obs_he.serialize() + hist_cen_he_serial = hist_cen_he.serialize() + # Send encrypted histograms to server + response = FLModel( + params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL + ) + flare.send(response) + + elif curr_round == 3: + # Get global histograms + hist_obs_global_serial = global_msg.params["hist_obs_global"] + hist_cen_global_serial = global_msg.params["hist_cen_global"] + # Deserialize + hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + # Decrypt + hist_obs_global = hist_obs_global.decrypt() + hist_cen_global = hist_cen_global.decrypt() + # Unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_idx_global): + for j in range(hist_obs_global[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_cen_global[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + + # Perform Kaplan-Meier analysis on global aggregated information + # Create a Kaplan-Meier estimator + kmf = KaplanMeierFitter() + + # Fit the model + kmf.fit(durations=time_unfold, event_observed=event_unfold) + + # Plot and save the Kaplan-Meier survival curve + plt.figure() + plt.title("Federated HE") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + plt.tight_layout() + plt.savefig(os.path.join(os.getcwd(), "km_curve_fl.png")) + + # Save global result to a json file + # Get the survival function at all observed time points + survival_function_at_all_times = kmf.survival_function_ + # Get the timeline (time points) + timeline = survival_function_at_all_times.index.values + # Get the KM estimate + km_estimate = survival_function_at_all_times["KM_estimate"].values + # Get the event count at each time point + event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events + # Get the survival rate at each time point (using the 1st column of the survival function) + survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values + # Return the results + results = { + "timeline": timeline.tolist(), + "km_estimate": km_estimate.tolist(), + "event_count": event_count.tolist(), + "survival_rate": survival_rate.tolist(), + } + save(results) + + # Send a simple response to server + response = FLModel(params={}, params_type=ParamsType.FULL) + flare.send(response) + + print(f"finish send for {site_name}, complete") + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf new file mode 100644 index 0000000000..5c81903a41 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf @@ -0,0 +1,7 @@ +{ + name = "fl_km" + deploy_map { + app = ["@ALL"] + } + min_clients = 2 +} From c072d31a23973835fae8698ffc11d729d5874a5b Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 16:02:50 -0500 Subject: [PATCH 06/20] add curves, modify saving functions (curve and km details) --- examples/advanced/kaplan-meier-he/README.md | 8 +- .../figs/km_curve_baseline.png | Bin 0 -> 16543 bytes .../kaplan-meier-he/figs/km_curve_fl.png | Bin 0 -> 17235 bytes .../app/config/config_fed_client.conf | 2 +- .../app/config/config_fed_server.conf | 2 +- .../{km_train.py => kaplan_meier_train.py} | 76 ++++++++++-------- 6 files changed, 49 insertions(+), 39 deletions(-) create mode 100644 examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png create mode 100644 examples/advanced/kaplan-meier-he/figs/km_curve_fl.png rename examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/{km_train.py => kaplan_meier_train.py} (79%) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 026b7f3569..f168f1a606 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -41,12 +41,14 @@ python baseline_kaplan_meier.py ``` By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory. -Then we run the federated job with simulator +Then we run a 5-client federated job with simulator ```commandline -nvflare simulator -w workspace_km_he -n 2 -t 2 jobs/kaplan-meier-he +nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he ``` By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory. ## Display Result -By comparing the two curves, we can observe that the two are identical. \ No newline at end of file +By comparing the two curves, we can observe that the two are identical: +![KM survival baseline](figs/km_curve_baseline.png) +![KM survival fl](figs/km_curve_fl.png) diff --git a/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png b/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png new file mode 100644 index 0000000000000000000000000000000000000000..34cdb9cabbdab5aa69026c7a00bf897d0b5502f2 GIT binary patch literal 16543 zcmb8W1z1&Uw>CU!1OyBuq$CWaQxOCOge3?{cLUV&$~4YIf#W7ZZm&$PE*${arikU8~!y&UYLft?X=tcrWk@arO- zIL49Ut6QPlqjzNDMbA$3n~}Lb8cQ$vMsAh-I1Te1gAO-aJtO(-UMvcI=cm4mjOTtcj?QHZ=`TS~V*^Cpt7tt`H`j@&q&-S?x)cr;2y#)p1`epYD>BKom~8RFELZN9XZiv01-k%3&sS*?D=rvpqRZ z`&-j4+NIKjoor((P5vw!lJOYKhlmkZ+*PD|en z+}zw4E_;?Sd(6Kz&mS%@dDIlbl=J@mRVQ4R904&!hf|*ctz&5WboW?eU{Ai0lv}mo z`r@!&?d~Ms(;eJYnay(bH z)}57^kPS;ci19E&`oT0<=Ymg)@lJW<_vL8hXyWunwlyaP;65*+qG-3gzr4CAzPHw0 zjKy^+AVaWXyWX`49*Y-gXlNR}@hXnYsBAg zqF-P*ut@aE2w-+s=+3zD9BX!H=ht{>9|4sF^&#JnuILCjjS(HYt(D0ExVh-vZ%F|x z-rDH8@>b3ZO78y}5BErE7>P3S&Keld^%_4p($Sl%$KFRyPR@x@kz9XqHTmYyru$l_ zywj(1nU+859>AhM^^Eu0=#NN)iNQ)3RA8URmAz*-4-6WVLVb0racHXJ?21^>}V;nCqhD)g=Na$>i_|&2U zYfF?enttUN!do%JX5hKp=pY`tySe7ERGsNE;-%Hh=%_UGc7C|Bba-59v26Ty(7o8gT*?g4$e)4frLJkZG74ui4{T15Dsj1nm>BOkMx!|GIACIU+ zJ>t7@UoTF-N_G5lQC=~Mx2H{PH;A+plH!$FW_QGq}Wq-OMlWBlAM%k1z5z ziNa#O5s=n7Gn_oR{+^t#Wqf(&g_?5s=Sgr6H~Hy;hE_Y}SwwPP4VT&(6^@8tXKrGi zai8>A@)+sPO!Us)S**Ml4d0z8Phl=;gF-%v#MPe=qv#WIJkV9Zz z#Vl96{rf`|>%DsGFyThRx~Ip^%%+B3?#kB4QCCw-O-W(%T8$g|8o{Zke)A?PuI1?R zbCt_kFd`W$$+fTq*ReyB)6;Vwj;VYdDtB~R`FQ5cq)^737oOSafFsp`bl|XcLM8vZ z-`3_uYIavX$MTyzyng-qg#Cr1;&q;I=FrnSx$cg7vs}Y!aojlg`WE#_a^Lm!bp}DR zCy{(c2_1WtE=#>cwE9VfU&Iv$OYU`RXc?q`=B-p}mpbr_kcTK57Q<~oG;e=l4F+b} zm4@N@E@byzLUOfLXrRW+V`5;l_P}dlq}B*5QlI6M~bQzsXG90|Ksd=fb=D_lbsssC+dvG`M(q zL&`P5&sS84Gl^przg09%8dOCwLig{_v29TwPa} zVXLO3MAK_~t@BKxX9^e!=cMG}hK}T>q*(?A4VSW8w7IXtI3ztLwbyap>Z;$5A&6Nf zJOp+229?eW9dkmLk&XF&wVs|HfxhgVoSt00(y~IIh>3M~Sj;$sC2&Gr_(>cz3bW_3 z>RH0N zPhZKs!J>PEq8Ryss=I?FH#|1x!?9*t&f9Ca@Ewb`*mNqX-J59eDF*XL9fJErxIA{f zM=R6x7fe3Vq8#}(6=!D+zKcQhZZt<$QIR@4r#nS1kOPzGv+u@1alPx%nO8}uuQ(}s zwDAxk?P`R~Bst^ScG!;IDxvq^lc7>7$tpW5!HyU8L5>);WjdwxM>(?)9@52ISs zm-}|_8R-M~B1;=zS*Y}L7Vme)$+B10KZ%Z$77yVAD5#@oMKW~qKBr5xhttZkLo^^}C=`kjU z2VF4?QEb>gE?T*7a-*73gjh`q9%FnSzIgy(Mn;MuYkg2ki&iCsQ8~a|Qa@(V1O^kw ziex-(7!B)Nh!AV$3}7GrJ;L`#`Tw&K!b<;zv>2}FwVL=Ska*!?A_j3cZ9~qtvBO&T z`y-YD`Ad1>eO?@ElODqigZWRW0lF#4Aw|)=35;%z`YS{ae=iz6j`=8Tjv%3Q;K%hc zHX-le`kS%x1PJo(h#XAi?Efw2uyG~J;&M=kjE2bF9diPbIzfGN#GWDA435D-*IJ9 z<1hSEOtuuC`siMmbQ~o{kmyiR{u7Qp?1OARm*38ujrY2Kf14N`{ZS#`6Jb%~T?LmP zx8?b0tL4iC*MFsgH82w~M|kQ+?Pro%g1>GrB<2L3k*Ca*DP+!5+AIL`jATP}{r=t@ zk~5EBQ#fuPf(`#qP;TI*c`95IUs8jxAB^sFGF4Jzbxq|p!&1d9!_uaW^Qy2#!!j^i z`VhqOzJbKW2${dOffBX>AGYE1dYKuWsww}A#bkIicRjQ8|6^5HQM1t&MR<*B^dazT z*xKe@3uZ{?(JLu(KX`ZNv#@a;GC4V!<+Aj>VtY2n>DSNEc@kOKgX;Ps#XO?*3W^q-i&_)AvRvb!&CnW_hg9eRJ4ly^es&a=5~&VzFXQRl)&A zHCJec=vQE~cx!hZ+Fo5$dqjM=>(k5gFB22<0Rc8vAM|J$^+XGpVsYtF*>)o};t(@M z?tFS7YH=8j!F-DYUdPgA&bxQ2<4wVV`i|o93}UvgxVX5+YNmpOzTCWX=MKkDuV149 z?f^?rWZzspGQFCWm&ZFr4%@-Y?Q&fI#*N^I^-0eH15NX2X=!bmLzwOxIiiot$hcTV^-69IrGjUkU2bOS0}qyJ zbg-y}h_4<6j(4x%=Uj5fr z-5I!dpLw5i#JkUCW#{Mj(P-%>qi{#r;_jDq_i8}jAGcuLkP{Z=^wOctsiO@6IaO5> z4m`)nvZ|`|Q{_WLc4Gbg{moCcPdZ>CC>j*L8$Yk3C7mshqT93x*uBEEMuhF4?HFzP zwF$Q=YIG3tjeszfO)q3_`QgbC6qrWx=w)ngyQBLNA?&q{U7zhygqcUA9XfW}64#ZE zI(XDR0D&5j)AD}xJ9LcgO2_Pwa==Qpt;jF@6{7uPiwp#(z)56plBh3VxiR7X$ML?1 zdrqH7g&^7ioCH56K0aS>q&2eaOp(jZ&hF|+mMQk$+g31XNs#ooZN6(;HcEVvk#g9+ z;qdlaw~AA%K(pn0ze5$cjJU4#BLh5Q&FSFEkCW0{|0H4Sdvr>Mxm1mlXp780q}6;I z{|bzQ%umq98RDt-qlT6ukuHwcIXP-quAuvY#!xJk{Na6HVjg$Hh@!EsZVlHLXQrm* zBOji0+BnxZJ63Gw51Wkn+CP80{5Qr?;Qh;k@qD2r#o3$kk7vi=A-+y;AGD6;xtqmt zSV`a4;z{6FpbTKICO}VUD7`fKr_cATv=iS)+g!?_YM~`fw~&T&lfzG7d4)1ShHB3; zt=H~A|Dv*%*$Ffarhwz?aDj29!r;?=K@8Yvt6T+a(|?)~&k#h{Q*J$q;3(}*9%9t` z?S=)3@c0SJUs3k#5M=*Kx)$+b$;sw!E!$OA1W_P`uiS=|Uml-0PbWA*PT2E<{4BL# zmn1jX=_yiQ`cLEg?Hg?n0)yG%xW)Lh!5II}+ZF3{*z}KZ-=YDl)|E$ES!?QiHn!b{ zl`k@enA6lB)u>01Bqzc;3cXKQ_X2PUiKclifPC@T*LDwg~c|jd> zrRL!&nHuSvhiiQCPqB+U0&hxovFr)i4!O80o?_q!lT@$Q!(%l_uR1#=*2g zqOp7>_~F5(B*aq>YD9?eC5*O?`wa^mtN(gR$pe}F19hs(c;OIYq;3(${fUaJ{!^(%xE;2YS zR9BNsV#Qxf-~Tvg-V5G}pCDR;kuu45 zzP+)K14&shgUEHglDk&x`+GZ^HUQqy>>0p1nnKM%j&`Hd5=>VxfYSo$%hk*5>Qc?s zFIP>y7TkpfTkl=hg`sjaN%u8N=h~5_zS5d?p7ohdnNZ1%X!rHm^iZjN?Z=NFXR%1Y zW;W^wBD@9_>X2KOn}2zA4x?M{ke=W+MW3r%e0_Dg-Ol*{vb`d^k;~N(xj%jU7|yZ2 zFevs?=F8bmFTB)Fb~w=<1^UR^&R+JLH(w!VVe5OPoY@q7GV|H9XJz?Yb4CZ*ovHHZ zEFdIhV!;j1py_Ru+vz1=J;{h8u!h`d)w0m zY@vf$5f43`mZyYwKHa~FL+^Y8vP@^$v&U6KnI*B1hg$24l(=4v0e-jEB0Q!wA)pjam-C+ASE55!+cxuJnUF2tQCueP_% z%sH=Ky^5avuBLoZ4pP6BoxK%T34u+IYC|C5R!P7j!tl7o2xgXDg*|`Szl@5K&wG6J z6#Q$mHq&|ENXWdEuKHkqhwSmya2v?ry5Nq3+`F*j-{L1BQF|E~Ia{-{*v+#4lcXBR zhHadRPxkiqwt~jbHC(spGgpBEbT7CgN8 z4Ec<`>F#wC7|RxSDIquJ`5Bd7S|iV7xxf=y+1b}+0w{a7*XP(DpF%Ua?ybam6p=|D z?5_Ssa06DESz7sq=o6;e;-cWLM{DZB!=@a>C+@rOeiqB6&S0ZMeCH?mDH=BJBmz+3 zI(rsYGZ!eH$i>kj2!A)@M)Gexf@}+Qj;diS$MC7dh(#RNPe!DX`1~sNd-zI@3GJ9O z;Jq&68~X?5kw%(XYk>P~Q6{Y8mX?E8mQ?txy4aS3M6buTnU9hV0cvMOQerZi-(wgL zlOWQfr;rr--&VKVa`~ze>uX{o)_3i2R?I~4S;wATKA}8jc>$7yLwFs;o(~TE#PSEv z1w6V8uKXK4d<2Oa{no1QSe~wg3H2$s?Bw{_GVaOymqe1vYzpT93VdiKEonaGa}yFT z6p@t)eA|F%@5?Db5S{+V*Bn3OQ6qiU{B(ijYcgVlwUQbsjBYSdUb>9B-WFa$X%0Ws zCBGx_K)|%_`qS5OPm(j9>$1Pd)MmZ}_o|MGrTFyuv(xBf+S#<{%0QIGr;+nj-Sj!w z8ENw*FREX_o*kLfH04iwRCCQsbCA`CB*z>4zq23@@*So8ma1027TajG#b45Loc^Lw zYCDVt2yFudsd@F`fv`WjC%*IIcJSpa1$noj@UdWiBi6Kr;Cgc-|0}luJ~{v3SK2P! zzu(N1iKYR-;0p9b@fqdC+12PIdT(dZ+dV)2f7@2{6l6cNBe|Nib|H+ER#UTLy46yh zr#t1h@kFiHFmj((%O;#-)$Rf6INA;9Y=GhixLK5T6RGqDelEMDr1wqf6%sc+OA00) zxDfoDk!Ptcm}|0tFa+TbMQmXx6ZZ5>LPv3*;Yr6mU@5zRv$X6TOnMY6>L1ek{r6S4 zL!^lG6iywfn#-$X^I;aY4707koxV3vgaMo#$x*#AVBHXg9?%)$d%ZH3d)O z22t7>Rsc)sYIr|^-FvdNKa(2TdugdE6=?c14!H#dH*+*|yAF#Db;OiHIP=Ab(;$#p z-Rr$zd}NgKcK%?$xd0e@ixS3)kcpT3w6s~mPs2Ld^Jr|I5-2{h2JBp#nu??ba8R(n z-By|L8E6&4yoM-()$p->3Qav&{%W)+K^YdO%A;R)({ZY`6Ot$nOq`eRxb5Eqd5<#f32k$U8EeNy=?FG9HS!$&zDv*a5u*H zGI;QOzz)R6gM?xYq~QN+z}UvEi^BPw;poOUuJztSxlG~88rxy<^`()KH>F^p(r02y z2&>FK9=!qFH4w)RW>i1U=U1hytgMvP!wErfCa{xGAu+hu{|+3poI3=|btoWYzkaO% z!L-w*W|gV%JD?RL-s}^Fv8idQ_sa7`%hC6b-66qi`}GpOQ8=n*jyO6vtVPxCHU=3z zrf_Ou`Fh1j8a-vU6d4|HWRqxM{3@}+7?sOgX*#u(Cw}~9YhHuZw%~Yg{1;H>tI^8^ zGwhw(v_V0ZcC?ioD}?htl*!dHbdzfcpz6CN&qTAj3e)I%Q+oB8teBrI&ho;aU_~S& zkDtOvt^SBFN&2j}F#rafg)ZXx;{l1fw)X2!pRQ|YHYIS3+Ft++#pcO9^ktSnpszu` zq0#2V!u@7t4l&}3R{=Pk7vez{zgJ&cnf~cu(%sLdZM62*KH>sEM_+EU0cL4j3-^y7 z^^ac0#i79CFtdaxKS8d9m-VsXBH<@Ka@e9wblu?N!Lw6El=zj+TqwACrxD9@?DQ9ska#@veJEfTdr_! z^Sehr6q>Ai2U#26Yt-L(ob$(~?-Ab=S0!8lu6_pKMg<3Mz0ar*qA{9$LbSZLs0S25 zj9sAUFZ=_mPxB7qvBmeS4x{;n8A~_*kBx;vr)^xOV#MWNKL{|Yw{C?4bSVNNqs5Dx zJ|hHZEyzB>fp^8_|H1T?GR9T=K5$5cVoo4pwnLV^`42G;j{3lIopsisTW_~VZ zpnef2kNBlc(Y$}EPIx8-H)jPum)<^toH|!>wM_DlOM#2)YXEi5KR}+mAlD*N1HU2d zuDx-#Z1Qtd4$w27+S+l;!be8>koM)-0mM3XwkjR6Eb{ljj)D4Y2Fe z3^m53GDbZ|vyQG05qS2%M^r$uLBw%N=`e$^!g;Mcb#Lz)&0O7{F&p4@z}&KdqMw0U zfzqhuDNK#DLcqyK;zL0H6*10A*k=AtHo5okkucuKL4ujh9St3wEF$s~vUwNdm7!{< z4CO_hn#HD*cItpM*Tz{6v=4qYOQ{N|%f_tR-j(x)?&)1!8eQF*Slwb9>Feona;jDsDs&}1HtsqfK=m#u^zydCvC~RV^vNHG5FRZZ%#FX~g5rZ`Zf$>8Eb{V{ zt35VFS53_x|9oq&`dPl=B4fPGFHAdWvZ$n{^W4a17#yt1wBY#RZ$aJnfSOp7RXlw-Vfls)l5gS zZmqXh8oyb4f~12uPk;Sso3z^}@^E)=yOzKXazt?wByqIG2%TAs$mf7)pViZsXIMS3 ziw9xM>`+E0oUpr1?}U5{62Qq_IW2cz59FwG~W;obyLfM_W%Sa@mV99-eh$Aeu6%sv^MiS6G)b6 z@u<(y8lJFKk+ZKV;UDERiF|~IShC==s`6IR7NpYD--zP5$dd=;L~WE%@{Jxq=AyU0lhWvo*O+nv^IK*C7;zbGGm0#Y?J zAatPtL1O`+)Eu~2wZB=0@M_ofec$V3E62zlbjqn@w3M@gE-e+Efb>p*a)9#Zw)89AbB>Lf6-l#T*HD=;Sfs zi0L|U+LKE`?bT6yMB0&1I_}p)l6GLJ zRG8!gpqvmByiSdJb9KT1kztmW1Ze~b0q=h9y?g_wJNGLch^Qi7OHIa4A_Wj+kR6QX zMyS6{`&;{u*FM9B!*AcKcgi;uZ)%sBrQU(L$ip4Q3(Y%HLLb|x=tcK-2*UEhfVeB@ zKD3sA9ey*|7Bk1Y^&aL$RtkoM#5YsGcq=%AIMfEg27K8`eG6Wh75iZbd2R#T2hEKv z)P;SkQA--l=ujOCP98sCUd?&fV0nZ9DL4iulVPQ1iHWzSh@a0k*8Kp%54UJ7;DTo9PHKRsR$9XlQ7Qw^z0w@ z)Jo3JEjOV4^X@RLIOq2Wd#=bw{Vb$N%W2pX{W7jHcAEN5#rMKb!RAaakWrpHmPZEk z8@8kGb-VS+iKV(_Td-0)lEIH))Gn=HXDK-kFR7xKxD}p#{fJ#Irc(-rXL5JyAP`eJ|3bejmYXo-&eLTirmIK zz4Ich_ru`t9YiB?xM=+0E6d76R$Mfhenn}xed!HCeB=$ROAw707+7bfpX`A|kK+3q zQHA|LfEM7x+|#R@KRi|;vbwRhVlN)$*Dym*iYSP|7s*z)6fDTDrbR{tbeD4|wM!$u z8nDlY%vp~0Vcp8oIY%u&8Zu{Cup)I_uzM7R7B#xM%*qyx{E`H}2kOV8U48pP5W4p2 z1$PmB;I+JW21kISzb@lX(K+lgk{ix0-&DPck{D<%0_AhX2!ARTG_5Od+%b`Qs!5Sa zz5{p5GW+OQ3iWA};^Hx^($*`r?SaDYZ;ETUQfzS#G+SLS&Z!mMAHSxduKpU*Ahg1> zj>}BUaGl_P&PV(qY>S#gG}qq2RG{hq z2n1+0_oeo5D%<_--Wka<0oEYQsQ`M0hU~Grp<%*Ucmu`gvmZ^yPy-DXpRxIBQG7VV zPx?GC(~xP#di`OdTkn)}8iPIxa66B70jonB4r@JBb|1}6W?Y1YgF4d!fu2TMT7AoS?B}{@?7Mo{GDL1{lDuP~Mmlt$NFVR?nWxlpVW1e`^m7MC$MNrf z$q!u3S_pA>k90@O>_1o|hXtv3@c5_rzg}FsGrWX_+sr_=+W~oF@0i!Wvt$By=3A6T zPYC25Dt&^9`qvCvjsLyUr2qW%jfgabE6uMiI_i23-zy!?h_kK0fHX$8O?$SZ_l}r0PJ(L{fA4^IGvd0UqKIh|1!N3Y+2upN0_CZ0WRxEx zWN|$rLnRR+t$Ws+H%bxMx!h7~s14~Bng^r2D$s@))hJ$rub?L?@}i7R#a*nA5ckrw zSL2CiINJ8?a|424{r(?yr3vi*w9zgIX2O72h2j(As>caF!#75Ivszh<;t{Oz8uc*_ z4*{AkjH1j}<;S_%S1!n8L2xY{g+6Qe&EWmH|3S7F`E(g@|#reT&R(ER}t)UCNpO2meQxsm#|xZWcQotYdp<2xj8W94??xo(p}E^X*dR(w>e z^yCeVo!bQi#rK_p$#<00awYl5WPzLjr&k=yOO@$}xzMfKRHT*Cmfp`tiu8h;)C{a8 zjN5)T8d4PW)VybMbVD@Cyp|QwCcvXE3156Rx>544&xQYdP6D5s^gJ6?GGew>G5-6f zw;Pd{Gb{Ebq)W7BQmjZC6&`hguWV#M172tL4&GoFww84V50BN0)K~fimg!v7lb=dS z8k+6^(>Ss8)Dsd!1NFp8s%^l{nE6VA@H#vcF2cd}gmCwK9PWZGo> zX)H)mImM5C>)(6$bdg5**bXJ089NN?Do-2how0didEL*^Z;kkw+I|0Z=K^0P+!?xD z$=o|*#m`5vM-T-7oyJF=ZX7G>a<}0l5{7LyN`nxoV}d}em#%Bg^8B0!e(sl`+ceKG zzVq@3q2Nu$=8wgVWHcXc3^tn$BuM877HU7feUxA>vvtg!0JN;o9`+{kQc;PdR zv#0SO=O#4^>t(oBnw~p8rEB^)My9Ck%|S4H^$wDP7JjDqi*v2>$r|RAc)vGDw6n9n zL4AW+X+rzr`O?6Tr(^6_4jFUCXiCT-gm=?~d@dHNyWXB!v1e;!EegcsDT!|p7DOnW zPkxkxhqqWDpI!<3lX z_YOzb&4JlQHk@&nS%WCzZSLgj|=i^ci=hW=0U>DK1RgN=_x)rIH{qby}`t_!BB z^&y2Jt%Uds&#gQx4*4S7r^5SlHkc(XPE@vvXRMYW2!`zYMPQL8g%dF#=iR>kz0xD0q`o>kI89KF{tXn{+<(r+S%RJc9qd5 zcxhL&59Y+t`6VU8460&ArLuqYh?1_)`voWEMm{Lk-+LXX&>gnozN@X+a z^jhF`Uin2U2Aw%Zf1>HL@^JHfuLYa!y(J$$3>4I6LCEaAznW;Xzq{48mM!l z^7e)>T^g&&KntRKz)aj<+T8M7+42luvazw@WWsbK#um1i%=MpVsozb#qcN7-msU8~ zO_bSQ`78tKJF+bX4vfe_rBV4cZ`JQYuXMyK!feWsXtF;lB~0NQa5&1~&7qvNJ^OB= zw)p`834oc7#XcQgi_5UA3iYkzrb;@Ad*Xnmm-eS5|oiMYt+yShz(-e`J={;izFbBe+3usQ`Cio|sv zpP!iv}{`yUk!!DX`$~)_yNMQ4$y;P8xD)jRLte6!2v?+6kP<= zz$jBH5=&!1y{{aoSOdw8wetenPQe+aA2D&uDWq?%2iiH}c%&X#bTcH}`0OgmGu~em zWYH}MZ{98RYri`JOe~Z+MWKNA-;|STW>9*Y9W2cU^eu zk63o4K>)1<`H-^Yx_b52R9txCNq@AMjQeI93bNq@R1pCDJzp@paxh5KQSSkLGdIyI z#Bm&Z+kLem7a2bBwogq)8TEo9400Q`fQHU_?q9MA4K-dn!cMd50HJSya$$Qp)aT73 z5~_@!Y8>)5?HQ0e{nZVU8X|t3lan*{vRB3Bo%zD(5tC+%?ypd?0u_1D-Trscb1~a@ zNqx^1{Ew_Lj#>A5^;(xAI&njf1(AE-3a$hP2cv1FhLI6V`Sff~elN7|h)HhGq`HI7 zZw&+)xRgNN>*KH0mY}%P;d8L*Ggs8D99VAKXR!*?v2KkP0Odp{)I>AUR-4ih@5+$N z+cPbarNho~pk?GmooQ@_5VhrCvR|9w?`+?Tj24F4gSh*;cP> zqTM+aQ1)lv{5IMD6`0MIuv5A09$#Q0&-TSCiW?iW11BCs?F&9P3A{xM_N996kzsR>uixk2YTw` zYe;ST_|;}X3V#hOJoxZQQ9AVW1eZHbGx(b-(Px}NxmPkKX zYA1&FQjDWz)?M!BiHC)JF;jp1HKf&Y|Hj~3G3sph^!8qls9AZLuMhzQa?iJKbgOTO zDH&g1_S*h7rPG!msRzv&S6YR_=J)RIgFRb=9I|2}Lh&n{DVV(yl*5cBrlCDS6MS`R z;kmPCJD}F_ZomH(bQeLB7aYGz4h+m%3hjsTg2jk5Te3LD6Fz+tVg>j#$avWt+Mx+y z0xQ10J(r*9*@{!w(>o6hY8TmhzOCnw~?GmXtsMY5{{3IZu zd;uy|e!8zW#D>$d)H1I}Z8$ln?SfC(oC-Yl6x50;M4g;@lRu zQfF^b(Cz08*_P&2t$ZqzYM-zGx>>gs78!P}4>Tv;lE)Q)f!h>zS=2|*=1TmFHO&o+TZh8(jEc|pb0}k{=0jl`swr`3r_0M1FTU;f zC+(p~A`Oj>;O{Xi#Lq z=uUt9-hNt_tv%VIJ)wl0k5&kK?$%!apU=TSFGP)qcJMguBq^ySmQx~M~E z7#tnd%TwxP0ksfSR8XJS18^lead5p~DNghTsHbah?;G--JJ$&sjhyL3A464tsW_}w z^0#`3+lTY}Y~Dak#&tU0CFp!!+55BH+`?e>U6MsFP(79P@#DwoApQzo2E|j0mW+(d zM1oGCIX&na1Pncw<;J0nO!!(T^SvNuN%7!R<-{2-iK$3Zb7h_=wtA1od0NsU7!#=D&4vKIwC`boEe_K?z$Nb7&O2!MuRnV2yq-R%sM?o->*T5Qm8-BePLM8NL zT2P!6ZFSvS?+b2=6FUi-12L$=N92W?#r7ryk1UCuR(@s4=rVM_33_&J@7rzwVnpNQ zF_mtZM9UDIG^Oy$rJBK+?#$59W!E5JYZ>yv{we}qI7EzjOaA)tDBJ^vYqmRYAteNn z%#vM7YUI}!ml{(~5O&~BVi`mZ`@#{1*36Nh`bQ)u4fFfxF4!mChZSksMcqBfzJ)AJ+M)V5p104N)PPJ)e|#?6W9bGwd_KouH!$NVqp)Bn~X{NG^|UJ`-my#bDZhGbaUOpmijhwUri+YeoU z4l=YE5ZWG;A|^I>ia|6PTO_q{C9X zqg3l4i9~xRUqIgnwCzA2v1z0)R4Q}9SkMMe27Y5-=+1RoZlujgPQC*D$}H%yWrBc< zI$#=~Y~;P2R11!)ol;Tk=+4U*7^Jq2_-1EHY?JOpOfKgRxsflwpu?EwF`P ziRr36*jqB+6cHAl0e9X7=fden9p3LYTiDE0^w|KSw!w7QhnKvyz{Pih=$;dU>VZ%} z5_J1!r=`iNNG$uI{khe98-ueWwLaIS&iqbSrB9NPkwH*-k(kJ$voZYERK9Qr{A=;; zMzVP5yNpN!PonL!JI{&x%BINJ&W@-RH1d5O? zNzj@CN>yn~!uxgK#ezypuJs}TQr~`chJ3RHuBVt^%X-Fr?R+P4?sV@IY zLa#L19iv6}{`8I5G!WRxp#EoalR#PmwK!;mhxY9M5;T+g{sZWTqTB!80{;&N{?C`R a9*~sVwI3Z}*VcsRAQ(Ai*}SX95C1>18+6P7 literal 0 HcmV?d00001 diff --git a/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png b/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png new file mode 100644 index 0000000000000000000000000000000000000000..a4765d56547286c2b1660ac44edccc550b861838 GIT binary patch literal 17235 zcma)k1z1(<*7c$klr{({2|0?CgruNIY$=ftX+&ud5J`b8EfTs_x;q7>8>B=)8U&<7 zq*F=|_~+7l?!8|;-+!OSbC9+6Uh%#&#vF4@pWalFCp$)e3_%bwjKcL>2!fAA5WETE zBXC5h`^z`@A?hfr?Wq31)X~}4;Xa~l>}Y5Gz|q>`E{oHBhldsqYz25kcrIOJF?V#d zdnm@sYxBwT zGg&LMG+EF)NUobjwdwd+*Un!g>95$?`C8j!$C)JrKC!>Nhl_2GZFQ|3_^9*kTYLK_ zi{lB))wSbGe54JO@LbjnQL}jPlck0j3qOQq5JChgpu=ZD5Nrn}0;jC~5n==}5+Ia@ z^D$=;d<6MSiU;4VGyT85QYUQM5uH|5Rdt=B!7;&OD~f?UVBn2@%xn^(Kn>RlFK}AY z5oHL?G(1kjqqMQAWWQ3qeNHgdX6T(Pd4P{iw@Sh}hO=6Sg;GXM`_aP2nG`-X6=z&Ar$WEyUF=V-b-Oc15SD zsj1_QnW^bHBSEBH8X+_~KYnRp9E->P+Uwv)g4q5>Z-1eg>gMvO?g>tnzRH!Fk;NLgrTOOq;~R>Ky#>Y%MV|Y+ znJYBJghh9oj>S3lVBomAic0AB_NWxNT`kr2^z`)HuOFR*KRTuI)l%fLJV26CNd@o?Y3}kbr#A%mxTSjZL5LYJ0IKIH4?6P2tNF-v@~8@OQ1@PjpHXEuDg|; zlcTArdCH!fi%Su=G2ugD2@iJR;>Esq_uDiK4f8c~v@Cw@thwzimC6t?@m8&R^Hi-~ zo1U4;DJ#=H{MqRx=)Psm=&{k=apl3MSW%ryr##cn_`EjWTD?l=l>y1^FDa$tp1S!p zyVh7*voP9Ih?s^-qkl7~&G6s>G`PEy| zr@3`fCEd5on}SZ^$UDUjCXPyV4dm$xIWOzQxqN-9Q|no6`YCqkjhUXV?#LpYB5rUY z#I1IGCCtq$j7hSu(5%zRLl}9+ht%Q2!kH)p?cS&LSJTzC+L-HYDf@;~WHq?0re=K4 z)KpH7ng6bDcfO&N3G6_U6}mcSTSdM{np_(+aGgFSgIil|ppdVcX%0CPqb~)=Niogn z-$SZvYHEgx-@m^$RW=p+tjwtg!%1frYcE&m{-9#+T#8v|d~*P^C#}85gqKCqmJrIk@|KM+^tXeQU8h zlu^9MZt{H~t6yQ~72Tgde}au1*LVN5*77{wsUZC;zPsyj4>kL_o7 z-?lL2ynp}R)a3a1QEh$wyyj*(baUJee$Hs>>yMhgZJY{|def^ar7zE5f4x!UoKj?> z?t^Mg-r8U9(v@S+6n(``ec*nQY`vNjtD0_yw=PQRTVuf6k_av182Q9%=Or-oJNG0e9u_*F#Lt z^RqJ1qXLezH*LmiB)>}LQ`O!&-?ls+g-s2p87T)0R{{Fk`6E``)DyS{|WWcRu+cuM{JcEV%D zeY_WfSK@SAgjF3D?}oB+|HDqn45@>)u#O-R9r~6bvxGdWH|6K%J5xFW@ukfm_Okn3 z^*MsUy(=u#jSCS2vtsuUFrl>_e52nn_QUXSHv(*IaMFoKRcgPe${;`9ZT>ts2cM&4 z89U>rNxSPFVetLAx+erSpWZP(Z0Eml{(Scnf}^aacrgbCNZ0A%Iz-`^pP%2y4bGx`;!tL3>RS_KpQk#d9J@q_^`6QkRD@e zx#f8!fj52Y0+?3M-HyYW&#TU9l5bh2j>ylRFzw^qh^6FNCD zVk=cYzWw;gs?)x+?>XwRJ!%tr0(lQnisk&hd-u%yvo)iycvPrJZffjrWJvX^$up*{ zuRBQYEf&&WwvAmbpJBq>vbTSGP9-*TZK`EqM(VJ9YiT3{g0e1WoFk5$L4?w+Q{a<` z%){^PYVbj)#m`9(KZ(kpKYxA~-RhmGu!T?JD}Ayw`hD|dpG0RE($do{dr~pc0%oD` z?Zmy6TEn5Cp<#&ljnza%L^{QmuRRa<#;sN->OX$^G;qf(E-fd&Mn_e0^VKJ*!^e

Di9jBH`$8B$)iELr^l9=x_L*ub9->(z+?7RWSzInQUC=y3ZDE=qd z2tSTRBgQ%9=+rM?^h6oRZ9W$ybqrbqILL^1%a&jYJ3_&fTWmGx^iu>d3hwnZzLCCu zoIeHgo7Y^on;j3AM=OQk@jbRH$NM4BIX0cTJsVN`>wcf74&3zlMSY@0I}ym)Zp!fz zb0v5>C^DcW!vC}ztHO|_FVSn;pVn+dVI^+IO$E<*Pzpe zQ=4wnn{G1WNAdBNMoLxXnOyHZlNW+dYx^0~X6NN;+SUVl9bmef5vE2|re zN6-hUnY{YS9D&hzjfdu)J9PGW@7~>l?NKv3WN)2X`!Ka8SQ-D~MfZKYo489Ho3Ay- zCtmix z69AIMzLEW(7`NF$mc{8|L3ciZK>~a0lH*_8O|K`Y@DS{BC=zM3^0?fPS!>`1u1C@~ zggI807%;Ma5XS`J!^+tCX^dJ>Tb5hCYK?Row;e1=dmX8fcUrmL{0ldi=KOAEg;jI7 zD&~}vEMb3Rb6}CjY|FbMNUTmQ;wvp@zz;Jv4 zEJ+tzQ`!rHJ;BNY3(Yuk;oT`I7Pt*pJ8}-&FxfTkc}3@p^Rilg-$~I}@(zzoK(??J z!dDR`%7BKmVIv%>L5k$^$?~uXi@qF2@c92Z#7AF8(7Pm3h76B!b(bAun;~Oy{taOm zyr<)3Lg{C&CeE)VmK&0IN2dmM>hJhU!xt{#CpK-|P=3ML#F5{bkFFZ_#q3#FXj^~8 zoZlkrO;@mCiq!LoZdkJAKvuLw`9B^g8+)RC;fJE58+e_4BP@Ga9C@1i7t4*Kj)_V9 zj@8}r3p8NE#gTYQ&x=k*G*f)S@fu~;)YDWfXRHZ;P244{BRnSj*h!390&Zo*brBC1 z;{PmfxSJP_&)!Ti9Fwah^$F*LGU##9ledvvOe0jgKkV;FJ{ou``t%GbTsu4xuL#R@ zo+pIibH>o7YRP8QI_0gVK1t5Mj%E$$O)95&&EcjE{u{%J z==e5FK5tpV!8#q~^nu0TX{3gm{U78v z-}N494B=#0uG?=k3=HCXdU^y++i7h^O7i;mZs%$zCCLP2L%~DOXM}xIwSI$9%<(Uk z6e0vsUI`r6|=Jo2X zc(kxFGJYWZiR~y9t5+0{)B&toz$x-{FAf#HPDse_&$-PdDXC8(u`1oSXXw5pgeGAC zuy5A^Q9{Ny2SjZ1i5@LwUqI%Nk@5NX@uMK1ye^@P-Yk+6=Q4|m1m~?(#FqqLzkcn9 zKMu$=8-=bjan1&Pd({T{mX?;RHDH{J?;Ygpp|WaS+2-Qp^N#vi5;97OgiG#m%wPf zx_}+bKrKv2t`>G6kvlvF?-d%dJ7PuEQ{+N=X1i00p$vwnw%A%6!lgnQ$n-iqxWK~` z&aNCSh-r)99sKZuO$bcH1L_T(DwnsA>Q5f?)|zn=d$nwzGq&(@a9g>|YJ=z-#y;y! zeMXOJ_O-#@-{Ci{=+(EKeei`x+LsWOs7tbjUi-!27w;qo1sw~a<-Z15oAi|btB45s z_8on%{Y?iEhKhsTh0W;*!}-jVPyl)_9e;FO=A6H6Ui#mH0tM}jN8+zzqIkR}R;|+q zXKtf@!BTE|8Jt!~B>0{#+Y-KzkPupBKYCPr>Cz=?NrwK*JFjZLxng>hpvne>k;Fe?g!Ak0n~yrK9#0xc(HILsUg!D01^J-fQ_)0_n)nFtx{`rz zk+6-p6|2+(R{rn>4{Y~;%H%+{dFw;wiNTG`-P8%EOl*mhs?7f(!n2-Z$e0g{l z9GNV9>7PtfIJWYsk!-fE##4BKe)X7za_t8;@Gz z2PuRZJmW?%6ju?1@+o4j28nQvsdT>jiqI}Ww$V-X1;HBMF$lcrRwETbL&S*T|GSe-96Tcf_F$Eq_Y#@C;B~d z9N>AJRo2bSXC;EBj6@Wa{I4PUCLyk|s-}U6iK$dQwB(}?yvP9sEHt12C7hsOMrpXp zM@~X~1RGgE7<*5`HpBn>Mrc_{&O2ejGhk+2r%jQU!C+=Rn4jyr`0yv`jx>0mHWl*n zc|-kAh=Pez+V~Kb3g}Fc%BQQ6SPdcq1Uo>9{!)#t&o+Wjn``5%-0lA7enLXNJ9(-e zmw61n$A1<%Cwc+hXZN>+u%~~6Bx4d0u!e)K1bvl&Tch8c;aNNh>)0p$4iV}phj#Vy z`aY_#NGQ1IvW2tH!@u?b*jOhb#q`u4sl6eUg*!Z~fhc88fp3IM>;XIw>9;8Z`~Nqt zP^=TSE0~O#H!gkI46x%fBVIuayrshpu?ZWEIZn-_1=vi04y>?7Mka)HwijH^tv+8g z&~aHN0V7@C{gu7Fd#%3_mOuV8!RHW{FDl@}jP^?@)knWMiv7zD0T-5jg{s5xR!vn6 zwyH+j{4}Ih76!PVy~tnd@JM-ui`1(hUvs^Vh{%8p6hbE;H&S9_wRwQDIAc|=^Mm;b zQ2PM8;{uG^LwkB|NZ9KI-$*AQC2DABq(b7eYLDdCd3)FI*zw~~AMpUGc)7ARDd{TI zpCv#gY&M=yMPuETxX8zBOURCs!P)4GBCcgJ; z)3Zk{jLx&$wTB*kd(WOd11^hh7lj>jwE?I8M5)`nsy7PrPe#0Cf-vye7XW}fAwC`~ z)I!bPvijAlPtuj5k{8}os^r|x4|O|)s#il_pK*8h>oeZ35?>oJkHq`g`ui83{r-Bu zbU3!~-MeoN zZyLSBgWR|UkZHO0&~tw?->V-ui+n}iYDi0ulCmct-(ElYRL^ZLZMgJ7!s(0Ji9kr* z{0%9}r%!jpT;^4)z1Ac4d14SRUp0U}<@Paf*l%X#97dSvOZS6n4tA|jKJL`1Q}n=m z3fq36y+C@Nmp7vKr?L|xtYy5{fvdjj^lP!{SL%JWUS8bV1=3&_Hp!lTWvjd70fG^| zNgB>G?;iIaGe#YIwuZsq_;bx-VS2h;R8&;Si=a&pN5z3DO`o;LesH#VW{Leu=0 z_KiSFm(1?`A*=CM?0Hd}OOf7sKD;bmbYRpjoGa6W{2vv@DvkDDYWy z93gyo_}OOoB$5DrzKUEIP!98ZDMC*UHH6gcd^8I9sJw{cJTgp`9!imYev5hWuWkU&b*tDW5GPcSx41$qB{4L>f3;v-?(ir z=Mf!0ZC>w9gz}B9GjSQf)I@Lco6G0~DNY?Vo69}%P|)dSxLbOKOw*aG28PlB^!pt( zq&~^y5L@buBRBF7y5*Vk{EpLD9{l=AHF*1tUKWa)4LyIR(J}XO03?3}#5Hh`z>=`L ziKL58Y|b@pbM?2qWDSssuHFJ_Evy)`G3kj7p%?#4QSz3)gLp@l@DT)gJ7?|0Z*#;B z=zTGWt3cw<#6AS4w;5_cujOe?169&FJpQ;4 zh5dlFZ{Q5P_8(R6&gblQJn-x-S9^cbRLDP^bq~Mi1SPm-PZqM%B6Yz3b&oAnCc4wa)W$M}-G#rl#1Ao9`#7-f=*Vzn+7*W-8+zlHUGUbO1uArjdLoBAP zt%Oy$QeTaIt3`-d#1Fg0jay9k>muw13QW4wl(Zv14>gZTr8JKSV1WuKcQ=+%WIf_guasx>VpExG)R6 z)#b{Sfo#njLFhSj6(8p*lSau|PvGOI0F1RCGlo2Glg^E|GE^v^5$B{+xB&l;hHAlN zpatwV@ukvntj|%30h2JvfyG@&Z-VzeP!?W+%|hbt4O6cIZcXVABf`SM zmS{%{Y7i#3*|P#;)Nn?>@fBU&t5+xY_gB0wPvt#N{VyK&lnG4I+|NR>?j@({fIr!p zb~kZt>>y27l3Xph(5_5IrQo!Kn8x6QIs_xPwf!5_JDzDJo_Z*8RQFzq3!m=Ny+GTl zH92(C(a?|fay&hMPDjLSss+AW_3z?lB#W4Lx^P)Q|DbLXJG=Q)b7ixYCXk61K+@N! z)5zg^L~W*O!j%hLT*&}v6=Uh1tTLut>IoPHMe1nUWC3=6yz!{%mr@leqi&mfT%}%vZfwD z8z5r!MM#-t|AkSn!~JNnN0zh0sJE!Ah|0qE=Y4A?&Qu8{(fcDTg-dn@`f z)jHnLoILqq=10^fze1umS^!`-qe|+nb~omLiq|Y$*dBFGfM$~EKnu$!Z_RpTq8t=F zu7d~P6w@PIe}9~JR($+tDEnN<<*}QtuH~lfkx3|%>H&4<A{gMc+>pj`({57ghR7+?f$0m<#O{DrO}jH6-4UYzm9b{Xu0pW6=r@~>i| z#+E5*^9~^&E2u!AgRB~}sl|u+6V3corbYfFqa<~4ST)J03PREaOkewGy?;a7CtJ?K z00`1#c$Dc-L_m#@wej)-=2k^Z9waj2Cvb`wPW@es{!V_?gETtcVEzNr$n^jsE}s}- zKXP6^lG6#W(hDp|a)bAhk@+z)SV{goDnVQTlpw<5;QI@G6kQHLbqAqDvO zEEP(<|MnZJi${Rjf}V&$CiKdwiq;rliY1Z(D`l>0Yw_<43XE zre#`2#sb7w^EDyQTEpISLrW`N_StE_H==PW%8`6m)6&woYnGeOJkW4*D)C%&O*wzhU3WqwS#DTGE|!-I92rO$8K(mPk(s@L7BlUA%O zEyW}Z#G!DekM7llLYVE%N_Ot+L7K!m0$_UH=^@tn!^1kt30L&D$GvJH7T$GJhVGbc z*`z-Y)1B>^zYCLwj=1U?=|&|M2pZ?$89+@@uE(RPr6s4ApPQ=zorIXt!&c6OS)k&3 zflIkIi2tPFN#l}}6*MO41ZOlIW z9;#(tX!QIG{${>b{DXhLcP5N49-L4SxeC2P=$b6#*BU?rs^Z5d@suOtVJ&Q7ExjC5 zE9B8k$2nVzQYO9M>%J%a!)`vKBu1oZ3iXixkdV^S(TRo#etK;e@~OhD!i7L(0A;v} zZ{}r2ijfoU0NH*&DJSrh>gh_|gN9yt9E#zY$;sE)v9GOXh}8n6e_tqYxMbD>;o^6lpO+i5AUe?Rq(k!-3B*XU@h_(R62fi|bG^=u!W1?N$eH?WIOimanW; z7d{6EO{{VltB~L!i5?Ji&J?0)MS7;-58u#^ygz9$epxwMUq!w!TOYQCkx<$(2>36E z$Zq+(MKt+;ij&A|MlMte?x;mzL)nmI^`2XLK&f+S__-I<=c6R0a?A*|0clTn3EVx= zLpwtzBQ340uWd?;H%wRZ)(Fpf-~hGUybJn%U02!>)@3kVg-dDNN~;>r(?vK45$tzj ztO9YN&U-=0Q>o{FKd~r23r+{lo4Z3dPb$L~9ES)NiCXWpt*We@Rey6VH9jK!kWgCt zh8=HW)T#Q(J%XVVfc)?%PgCX|8`HdVrs8xi@v@Tkvl6N2l8{~~5o5eb&3aP}!lOc? z%pc;V5riCGN0Bhh|JkOt2&_(wVC{* zgJb`Gq&G^)=d|ZVJy@P1|8V}_cQKx9Vlgd2EoYg!mj5{M32TyGjRMrXW+hA!e5V~X6(hB~SdSn^M+v2+ivsbG&&Tm7^?5#@)CsEQcihiVKnG!Ai)3ZN8WN!uyh%yXvzi^j{ZK4f<%-uj%3Ld$Ayi=y}CQ z!nDO&rx>bCw;b%$7bysTH>!+#($3dV zTg>qmzeUAR!>;ra>{!e{WrFfD+(phy`e;?iZgzzM8KcM;cNzQQoRK&k6fL22(v8Y4 zns)N2TVqAR?sArO&MQ8nMWio*kSeud4)%@8!E*6CKY`X}jh@Rp;Qn~x;`M~v&-+PEXMTxxNizB<*twLmRfhohc zD0S{YdJA}~TyG>W^_1rY1@*}cU*AKCF~#@F5rwFYpn3gbT6Ie!kt&ge^3J*)YlN@Q zEFCeTAP6UtEYuXtj$KO)j|ezX&aTucjbJswc1TQ_PYgP`mZfn_+UZwI{PQ@SY3JkO zB8k`!LB?)J;I3GNShv8f$&PZ501sI+8bZn(rf-aT?k#0!W_Ce_egsI)YGbxrZKe9~ zU~jXuc3)ZCdHEIlsgoyX!*y*f);vF}=ALSJ`X1c^i973G@*XV=WMyV5fP|+X$S9%5 zJ9ei|pSD~Xt5#bXe|P^`x=LI&0I{5+A`OBgq+Oyju_>SqWa?dUjq$f73zx;27#02Zl9&h5w5R(*lWwR=z>r#)hNhY8`>vir3C$ zz~+MCp$;1UQAv*!2p^6S;)A;8>5mV_{zkHPz*7*6$^D}kR=<|`8hEGq7xR`|)Aj-j zM-jdV;EeuvfWTiZ@8MBRf9A|)-UE-E9RH_0_qO=$M`i82SoWZE78_X{84v-B2Gng`qk+b-6HcBAcT=IG<>|~ciu$a;aEwa zg>r$cQZnK8(Ra+0h|v+~U8Sd|XD22~13eZk;rfo<@0b{x@xLzERxSyTf(8balc54l zEvbgzDuqZ<>&C0Au}0^&3cXFJte(+w@lZf+qyqGSLut<_0WDQta$Km$vg*^0zGN4(6xs`8F4f3{p-+d-N`QC@#9FKN$ zC(?DqgxF(GMFaqCMGUP)w8#YfF@WPx)KF7PqGFRX2VySi{d1^wMB~PPHPk|hBP1v| zoBWI?$R#07F=F1ic28d-Ae86%hwO=aqE{iNZ`slM8}E&N`|r)8qDkna0h-kaYC9h` za(~{zPSD@Ds9d7Tbf9#K8lgaWosrD9KJWV2iuV)bkJBNuOky$m7ext(VLNdJLv8iS zP46Tacx@Qz22bmJ02mFq;^q#nU^M!AK&h_CEpmM6ZBPb_WwLPw zvb$l*QudAxi4!BE5SvG@?eN=AR#eCvxvNsX?ZJVb5~RcmmffLGj}P`s;_LSGtQWiv z_Z`Ve`$~1}IlK+bm5Bm*}M2{7hi!U#Q^Cr%& zEId5mW!C-RKDRYuB39&iEEFUq$=np2mKYJeOGUvJ*~K%mPlyqAW~fc_(%TxI>QSqd zg!3<*vU+H-bsQfnM2U3g$r^fxcs1F{>oofo&ad0K$0!yWQX;ue5aJO9$0j>>-G3a9 z{&U<1ju%)q*}dbXDCOk)`=h#V?_z@Q9wsxTiKTV;8f-p~ zFgG5C^)q$4fH!(}tCThjo6% z7Nn~hF0P^wer(d5GpIBQQR3(_Z9(Vpr_w0|u%Flve0}y|$SyH*&Q3qe($Y)>!UJN2{cTCeQt;{+2&EIB3FZ}pWE2YMz?b(V)kVV7rnI+D zJ2}-Fxo;hde(Di`RGh(Qq2zq4K{|MO=Np9}1)WwbBPq>-l9Ac=@#c!#T}q4Mgk6mA z*dmq|N6yok9To4Fen(Luc#WEhOD4Gcs)_wkYPr|=NxJsoE!g{d>>X_0UL&&a-QsU8 z;q1z8zo?ZINjQ6eUfCkY*ikq4*3~aBw)oBC?lfIE`4ZUzEj&)aiY#TMXTbhp-rM)> zXNOG(yyW22C^H_V!X?Hy(S`QKBu!ItJmfowZ-u+NhZ8MYu4{ zU(e&SV6}=%9Iczz?Rw!P$pe}bZMWV=EtV{{8>lAk|G9jkvQ0-x({MyV+Vt1062GU; zhBg>*LwuCvX*bQAr4)xfJu3c5?9vJ1BGq?GaU7ta_SH7cA9mk} zy|Z{}A@al27WQ6*wY(hqjgTz{Nwe_l9b!(`&CD*mOj{x|;^-dcb3LN4PBsKApd>GdJ#?!M*IJkdq4~B6I%bOx&(G?VSVw`#wn`+FR$8KQ z?@*u9a^Njndt{bUNy%mGYu^WPB0M=6pTzh%pQKpCaS;4}8kmR0g9g2QS?ZdAQDA<; zFic`S@SJLVE-Dv-84QvgGQ^SQY2>3tVga)YOJy@uvzg%~DvxzEGe9W0siKJK`k*tz zcTq%M6qTIpjf|rc3v3vGZ74s%Q@MEQPm>cWd^OOK#=Cxx>MCA&JR^?wzX1M?exCu^ z=3fE=>Y!aXD^)BVlN0@oqU?oxlGPKu{Ll$+AgSDjDi0&MDmlm2cl5kX`_b|N5!}3R7Y0 z&t0!&h>$iutMqsEG+kTkhuNU2xNWn2pU_Tw^-zTFv>`V!kUF{9*|_nQ+bSyGDqR05 zJpzOr=;kFWEBjljRAumBb`QEDn1EhEvZ+EKwd!AZl4lzDG~gCrzI-9C9)nv532B17 z85Obv8_2;(ZjM4oDKF2y!C*#lE%*hUkGVD|41*7-im7`1H+;OVx6qprM%NIAk(2Z1 zIA}AGm52#nv({-Yw4m!uzyIeK(=@Z*7g4n=ZA8+T3l}c*v_}a51pO<7Nm7SPEB7W? z1q{4!W1{&@{CBrk@G38y?g$eYv!UY4U5!sqY=XReKCg7F2Sn`=9}IzatHS{^ znE~zO{MrK@5VNX5A6PD!HVO1-{>LO?++&*5O4SZ@)N?Q7nYy8qlRtj2RZ0)-0k8mR z3Mk95f58rtth zNvj+Xu*j@fk?T=ZqJhr>*Zk*X7cJ-i5bSYx%LG-z+o4{f1%(w%P~^3~4m4wGN^W&1FGG6;M)i}T1>z&SKwHEX%R| zFMOnEUm2zg)ZsV??g(L`NS=hj{xOUKpzWm=Cih&ojoBo0Z~+Dkf<)*)@^pTO9?t+U zjjv*2xB?ow7ccH;$2Wi>3EB(`LaPt$gX}i(dL|JZC2w3Sou$u{jDfKfefOo}-@Pr6 z@lSvMZrL162i=x@*0F}r&h)JEhch2R5d_0SRD&-!hbWOLqtS{-=)U))DHeO|Igv46 z(F0-eaHaEuuPk3}8r=q;B`ku%yiQYE?n zE`c5m>g!ay6_2@0k!>yxsl#)K!B7jn{McqVw{AK*MiuYA+>Q49p{kiH8L{mhbslG| zXyNHh2fbfI0L37PNKy_Y!e{tkjjH#SD~F3LUP4r(tNyj#268~pQ~8ETZADwwFa(3X zExPt7YyML8N|t`5)13Elp1`us(9UyTb~hI`VI)Le@)4UwLdX-I%WfOyCLl#gTG{uk zeGwC51+jtCPUC5*^NIlsLiH%}()k4Jr&_`utj#=;uA3bx&9{eXw179D#2<%QhvvgQ zuR~rY8KualCphWRqMPl_{;a&zCA1Sn>+g{e!fpbFm9%2wN{%k+*;W@=dYSECKN1cF zK@GPCR@7{>sGm1i_|y~{3;yg=kV!6+`k4KwJ>096gWhjQoVYW63&X?UHMHNlQFzZ& zxE*nN`;UPA-k0Rx12(>VxC?AQF7<)Of-Wu#lYOc4OANmJz^8knplC%y#h!N8L9DV_fjxC)_7$YngA>9T}tS%$cRhF;ix@; zYrv%*{!BRRhS9c7Xhs_1suPsUwBu!9fC>6xDKJRrw;aW1>>aC;0p=1?7AA_;-_S+4 zCZ9U|a|nbjDCG}hPsJ)?bE&3Ms^85_hUB_NWi>Tl|MK*jJ^3;bW)eJB>yMJqYknod z$J4iOItAvbH`^EvH8PCG2#ZIEa&pRp7=(oO{^Q@2L5)WxI;vtSRrsS2Rp1mu2@L1w zfWXOv6UzF~%S1<3>+O#_7VKfs}2nju!@9T0ch1dcEk1t1-24F9lHfLIYSwpEA2$~{` znNM-U6_2cv1MZTb%mXTUfiZ#e+}zyV4R2pbvn0Yo%>kwKf*lpoS_^%HddN#J`NZXzVNp|=C* zMjd=6qG1w*(@HJWb~J4tKjPY^Ew%FQ+_{ql0qUH&##14$gU6CT+69JTs3-Kw?rhg4 zoX7^56u3xO**JM;c5bec`}T5wR!01V(h+W$i-h4gjPvNDh%dIK4h|01d)DBC8NYu0 z3bcaJqGU)K9OpOoM%KJc&|N*vqnDX*u#xe5P7Nk*P+1PDv%R-sc7i8-rVFyLsh{hW z1l7H*rNXAdWB#?BQipmF!j5$uV75Gj^Q+c_ww6b8U@*-8z~BHR9-A@azvN&f2#4E) zxLB$EzY2uU+1*<+qxA>cCLfhQG|?FL!9>TIwIANz_-U!B*ClcwjAVgCY(y$&1*YHJ zKsj9mQzKo2nQAEzxQn_HmUgYt0k&NIE|`JLgwZ59oa5r#mh1=PHB_Cz5Xn7&4tc4U z@aE>GgBaT(y>q!9;5FRhzV-X2bp}*S2C6u8*w?Z18E;J{>@lnC)es%RMU6aTt04HHHBFkgt0OdK3P3o(wTf*2)$3G)Y47K|ucRe=t> zIp1@WD+Mx(9Ka|v0l=nYy@-mkfDxpo&yZ4Pb82d?ICetQI0TlC9<d2-j?W!7M9ZsN|OR6wC>cFt`rJh(1j32A$Ou47cz*fx<%qL>s?Huwb^!clm8w zVHQNyG0#04bl^&uoephu!>>y>Ijk;rPJVgVDI^!lU{ST!e5M%!Nyu3le_EK%Rqo5b z7d;TK4ElHbdLkNH7(OZj*c2LA8&eISQF3v|I@v;p$eOlw4;;L#bbr6d^DS$Zp$3(kF;OK$3u2 z>SmaiCJk#biM4xA(hOLMX6=VGi_t}cs?Q%bAGs87bh*ngM79 zqsxqr!`Zhmn(2XmP#xNhOb=O)Rf*1z^TGg}ASqkP%xc2nua-!$N7>ylw#vrF)({B| zT|lWZ*zsmP00yI#&ZRrf0AXTeR*YNBPAC^4k-44C#3%hGW1JlG1F%E=hTH0I{! zDu8%IC%Dl`YKRTJu=GjjShi=a)W#U7D$%|-3RvSM-7P^H3UsgHT~9Tnq#i2DqDgcR z5_LaR@ByvqSL`TQ=vdCl&mYux9!&#_(16A(U>VO;=7ZIM-}iQ37WLSj%cuZ_jgajZ z5oM6zqg~F&FnjOr2KN0rCME|B3Sf5q=s+Y`TjS<7p%F7Inz;-lIk2DC@$pPLTl+v> zM~@DE{+tPM)uJo$DlYXIPlW=-m7PRTwt$$G7begbD#!QV2tOFh1tu1qD}yc^5OXrkI@tKCDv4_Fw%x_TP7Rp5k8WBB zhQVk9R7Ma?C$IrCo0c$?JX~TE6ZGuaCVVQI&k|J`5+nR&^VvBdCw&ETEGeO4H{kKp zt5%|gtWY+4|NG_IZb*tkuIu-ze}0ca=K{~KP$1oQ)Wj&R|4)a^{!0-6-~3y0{U87H d4(L}7nH8GiRkpuc8GtV!7+IC;Z>|}6|35W>@2UU* literal 0 HcmV?d00001 diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf index 9de6ad8d7c..0bbb867b6e 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf @@ -3,7 +3,7 @@ format_version = 2 # This is the application script which will be invoked. Client can replace this script with user's own training script. - app_script = "km_train.py" + app_script = "kaplan_meier_train.py" # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. app_config = "" diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf index 618bb4b0b5..de850e47bf 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf @@ -12,7 +12,7 @@ task_name = "train" wf_class_path = "kaplan_meier_wf.KM", wf_args { - min_clients = 2 + min_clients = 5 } } } diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py similarity index 79% rename from examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py rename to examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py index 30742eba2c..04fb152de3 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/km_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py @@ -31,7 +31,7 @@ np.random.seed(77) -def prepare_data(num_of_clients: int = 2, bin_days: int = 7): +def prepare_data(num_of_clients: int = 5, bin_days: int = 7): # Load data data_x, data_y = load_veterans_lung_cancer() # Get total data count @@ -57,11 +57,43 @@ def prepare_data(num_of_clients: int = 2, bin_days: int = 7): return event_clients, time_clients -def save(result: dict): +def details_save(kmf): + # Get the survival function at all observed time points + survival_function_at_all_times = kmf.survival_function_ + # Get the timeline (time points) + timeline = survival_function_at_all_times.index.values + # Get the KM estimate + km_estimate = survival_function_at_all_times["KM_estimate"].values + # Get the event count at each time point + event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events + # Get the survival rate at each time point (using the 1st column of the survival function) + survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values + # Return the results + results = { + "timeline": timeline.tolist(), + "km_estimate": km_estimate.tolist(), + "event_count": event_count.tolist(), + "survival_rate": survival_rate.tolist(), + } file_path = os.path.join(os.getcwd(), "km_global.json") - print(f"save the result to {file_path} \n") + print(f"save the details of KM analysis result to {file_path} \n") with open(file_path, "w") as json_file: - json.dump(result, json_file, indent=4) + json.dump(results, json_file, indent=4) + + +def plot_and_save(kmf): + # Plot and save the Kaplan-Meier survival curve + plt.figure() + plt.title("Federated HE") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + plt.tight_layout() + file_path = os.path.join(os.getcwd(), "km_curve_fl.png") + print(f"save the curve plot to {file_path} \n") + plt.savefig(file_path) def main(): @@ -164,36 +196,12 @@ def main(): # Fit the model kmf.fit(durations=time_unfold, event_observed=event_unfold) - # Plot and save the Kaplan-Meier survival curve - plt.figure() - plt.title("Federated HE") - kmf.plot_survival_function() - plt.ylim(0, 1) - plt.ylabel("prob") - plt.xlabel("time") - plt.legend("", frameon=False) - plt.tight_layout() - plt.savefig(os.path.join(os.getcwd(), "km_curve_fl.png")) - - # Save global result to a json file - # Get the survival function at all observed time points - survival_function_at_all_times = kmf.survival_function_ - # Get the timeline (time points) - timeline = survival_function_at_all_times.index.values - # Get the KM estimate - km_estimate = survival_function_at_all_times["KM_estimate"].values - # Get the event count at each time point - event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events - # Get the survival rate at each time point (using the 1st column of the survival function) - survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values - # Return the results - results = { - "timeline": timeline.tolist(), - "km_estimate": km_estimate.tolist(), - "event_count": event_count.tolist(), - "survival_rate": survival_rate.tolist(), - } - save(results) + # Plot and save the KM curve + print("plot KM curve!!!!!!!!!!!!") + plot_and_save(kmf) + + # Save details of the KM result to a json file + details_save(kmf) # Send a simple response to server response = FLModel(params={}, params_type=ParamsType.FULL) From 2959b3fc77ad29fc11a2f2e865beb6d7bd12d590 Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 16:04:53 -0500 Subject: [PATCH 07/20] job name update --- .../advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf index 5c81903a41..a393deb907 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf @@ -1,5 +1,5 @@ { - name = "fl_km" + name = "kaplan-meier-he" deploy_map { app = ["@ALL"] } From 484cf3d827c7a3c3b40ddba09374f6e0b51df8e0 Mon Sep 17 00:00:00 2001 From: ZiyueXu77 Date: Fri, 5 Jan 2024 16:15:17 -0500 Subject: [PATCH 08/20] remove redundant print --- .../jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py index 04fb152de3..a3c3618e0f 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py @@ -197,7 +197,6 @@ def main(): kmf.fit(durations=time_unfold, event_observed=event_unfold) # Plot and save the KM curve - print("plot KM curve!!!!!!!!!!!!") plot_and_save(kmf) # Save details of the KM result to a json file From f2e6d108ef400d9eeffa01214dfaf06b28d399de Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Mon, 8 Jan 2024 15:57:53 -0500 Subject: [PATCH 09/20] move data preparation part out of local code --- examples/advanced/kaplan-meier-he/README.md | 6 +- .../app/config/config_fed_client.conf | 2 +- .../app/custom/kaplan_meier_train.py | 45 +++------- .../app/custom/kaplan_meier_wf.py | 1 - .../kaplan-meier-he/utils/prepare_data.py | 88 +++++++++++++++++++ 5 files changed, 105 insertions(+), 37 deletions(-) create mode 100644 examples/advanced/kaplan-meier-he/utils/prepare_data.py diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index f168f1a606..45f84a43e9 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -41,7 +41,11 @@ python baseline_kaplan_meier.py ``` By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory. -Then we run a 5-client federated job with simulator +Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each clients: +```commandline +python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data" +``` +And we can run the federated job: ```commandline nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he ``` diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf index 0bbb867b6e..bb0ca87a8e 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf @@ -6,7 +6,7 @@ app_script = "kaplan_meier_train.py" # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. - app_config = "" + app_config = "--data_root /tmp/flare/dataset/km_data" # Client Computing Executors. executors = [ diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py index a3c3618e0f..a01e0a0d2f 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py @@ -12,51 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import json import os import matplotlib.pyplot as plt import numpy as np +import pandas as pd import tenseal as ts from lifelines import KaplanMeierFitter from lifelines.utils import survival_table_from_events -from sksurv.datasets import load_veterans_lung_cancer # (1) import nvflare client API import nvflare.client as flare from nvflare.app_common.abstract.fl_model import FLModel, ParamsType -# Client training code - -np.random.seed(77) - - -def prepare_data(num_of_clients: int = 5, bin_days: int = 7): - # Load data - data_x, data_y = load_veterans_lung_cancer() - # Get total data count - total_data_num = data_x.shape[0] - print(f"Total data count: {total_data_num}") - # Get event and time - event = data_y["Status"] - time = data_y["Survival_in_days"] - # Categorize data to a bin, default is a week (7 days) - time = np.ceil(time / bin_days).astype(int) - # Shuffle data - idx = np.random.permutation(total_data_num) - # Split data to clients - event_clients = {} - time_clients = {} - for i in range(num_of_clients): - start = int(i * total_data_num / num_of_clients) - end = int((i + 1) * total_data_num / num_of_clients) - event_i = event[idx[start:end]] - time_i = time[idx[start:end]] - event_clients["site-" + str(i + 1)] = event_i - time_clients["site-" + str(i + 1)] = time_i - return event_clients, time_clients - +# Client code def details_save(kmf): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ @@ -97,15 +69,20 @@ def plot_and_save(kmf): def main(): + parser = argparse.ArgumentParser(description="KM analysis") + parser.add_argument("--data_root", type=str, help="Root path for data files") + args = parser.parse_args() + flare.init() site_name = flare.get_site_name() print(f"Kaplan-meier analysis for {site_name}") # get local data - event_clients, time_clients = prepare_data() - event_local = event_clients[site_name] - time_local = time_clients[site_name] + data_path = os.path.join(args.data_root, site_name + ".csv") + data = pd.read_csv(data_path) + event_local = data["event"] + time_local = data["time"] while flare.is_running(): # receives global message from NVFlare diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index b775afaaa7..9fc2575652 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -79,7 +79,6 @@ def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): max_idx_global = [] for site, fl_model in task_result.items(): max_idx = fl_model.params["max_idx"] - print(max_idx) max_idx_global.append(max_idx) # actual time point as index, so plus 1 for storage return max(max_idx_global) + 1 diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_data.py b/examples/advanced/kaplan-meier-he/utils/prepare_data.py new file mode 100644 index 0000000000..951a93213c --- /dev/null +++ b/examples/advanced/kaplan-meier-he/utils/prepare_data.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import numpy as np +import pandas as pd +from sksurv.datasets import load_veterans_lung_cancer + +np.random.seed(77) + + +def data_split_args_parser(): + parser = argparse.ArgumentParser(description="Generate data split for dataset") + parser.add_argument("--site_num", type=int, default=5, help="Total number of sites, default is 5") + parser.add_argument( + "--site_name_prefix", + type=str, + default="site-", + help="Site name prefix, default is site-", + ) + parser.add_argument("--out_path", type=str, help="Output root path for split data files") + return parser + + +def prepare_data(data, site_num, bin_days: int = 7): + # Get total data count + total_data_num = data.shape[0] + print(f"Total data count: {total_data_num}") + # Get event and time + event = data["Status"] + time = data["Survival_in_days"] + # Categorize data to a bin, default is a week (7 days) + time = np.ceil(time / bin_days).astype(int) + # Shuffle data + idx = np.random.permutation(total_data_num) + # Split data to clients + event_clients = {} + time_clients = {} + for i in range(site_num): + start = int(i * total_data_num / site_num) + end = int((i + 1) * total_data_num / site_num) + event_i = event[idx[start:end]] + time_i = time[idx[start:end]] + event_clients["site-" + str(i + 1)] = event_i + time_clients["site-" + str(i + 1)] = time_i + return event_clients, time_clients + + +def main(): + parser = data_split_args_parser() + args = parser.parse_args() + + # Load data + # For this KM analysis, we use full timeline and event label only + _, data = load_veterans_lung_cancer() + + # Prepare data + event_clients, time_clients = prepare_data(data=data, site_num=args.site_num) + + # Save data to csv files + if not os.path.exists(args.out_path): + os.makedirs(args.out_path, exist_ok=True) + for site in range(args.site_num): + output_file = os.path.join(args.out_path, f"{args.site_name_prefix}{site + 1}.csv") + df = pd.DataFrame( + { + "event": event_clients["site-" + str(site + 1)], + "time": time_clients["site-" + str(site + 1)], + } + ) + df.to_csv(output_file, index=False) + + +if __name__ == "__main__": + main() From 8cb6d7dc2a58b0b7e3fc83e69a863f0f693ffe07 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Wed, 10 Jan 2024 15:14:50 -0500 Subject: [PATCH 10/20] move HE context part out of FL process to better accomodate the transition to real application --- examples/advanced/kaplan-meier-he/README.md | 7 ++- .../app/config/config_fed_client.conf | 2 +- .../app/config/config_fed_server.conf | 1 + .../app/custom/kaplan_meier_train.py | 23 ++++--- .../app/custom/kaplan_meier_wf.py | 33 +++++----- .../kaplan-meier-he/utils/prepare_data.py | 2 +- .../utils/prepare_he_context.py | 62 +++++++++++++++++++ 7 files changed, 103 insertions(+), 27 deletions(-) create mode 100644 examples/advanced/kaplan-meier-he/utils/prepare_he_context.py diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 45f84a43e9..58cec6ebe2 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -28,7 +28,7 @@ The Flare Workflow Communicator API provides the functionality of customized mes Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py). With the Flare Workflow Communicator API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: -1. Server generate and distribute the HE context to clients, and remove the private key on server side. Again, this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +1. Server send the simple start message without any payload. 2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. @@ -45,6 +45,11 @@ Then we run a 5-client federated job with simulator, begin with splitting and ge ```commandline python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data" ``` +Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +```commandline +python utils/prepare_he_context.py --out_path "/tmp/flare/he_context" +``` + And we can run the federated job: ```commandline nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf index bb0ca87a8e..0704590617 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf @@ -6,7 +6,7 @@ app_script = "kaplan_meier_train.py" # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. - app_config = "--data_root /tmp/flare/dataset/km_data" + app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" # Client Computing Executors. executors = [ diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf index de850e47bf..2ade684715 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf @@ -13,6 +13,7 @@ wf_class_path = "kaplan_meier_wf.KM", wf_args { min_clients = 5 + he_context_path = "/tmp/flare/he_context/he_context_server.txt" } } } diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py index a01e0a0d2f..401c26aaf1 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import base64 import json import os @@ -29,6 +30,12 @@ # Client code +def read_data(file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + + def details_save(kmf): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ @@ -71,6 +78,7 @@ def plot_and_save(kmf): def main(): parser = argparse.ArgumentParser(description="KM analysis") parser.add_argument("--data_root", type=str, help="Root path for data files") + parser.add_argument("--he_context_path", type=str, help="Path for the HE context file") args = parser.parse_args() flare.init() @@ -84,6 +92,11 @@ def main(): event_local = data["event"] time_local = data["time"] + # HE context + # In real-life application, HE context is prepared by secure provisioning + he_context_serial = read_data(args.he_context_path) + he_context = ts.context_from(he_context_serial) + while flare.is_running(): # receives global message from NVFlare global_msg = flare.receive() @@ -92,14 +105,7 @@ def main(): if curr_round == 1: # First round: - # Get HE context from server - # Send max index back - - # In real-life application, HE setup is done by secure provisioning - he_context_serial = global_msg.params["he_context"] - # bytes back to context object - he_context = ts.context_from(he_context_serial) - + # Empty payload from server, send max index back # Condense local data to histogram event_table = survival_table_from_events(time_local, event_local) hist_idx = event_table.index.values.astype(int) @@ -108,7 +114,6 @@ def main(): # Send max to server print(f"send max hist index for site = {flare.get_site_name()}") - # Send the results to server model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) flare.send(model) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index 9fc2575652..1ec5e372e3 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 import logging from typing import Dict @@ -31,39 +32,37 @@ class KM(WF): - def __init__(self, min_clients: int): + def __init__(self, min_clients: int, he_context_path: str): super(KM, self).__init__() self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients + self.he_context_path = he_context_path self.num_rounds = 3 def run(self): - he_context, max_idx_results = self.distribute_he_context_collect_max_idx() + max_idx_results = self.start_fl_collect_max_idx() global_res = self.aggr_max_idx(max_idx_results) enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) - hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results) + hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results) _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) - def distribute_he_context_collect_max_idx(self): - self.logger.info("send kaplan-meier analysis command to all sites with HE context \n") - - context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193) - context_serial = context.serialize(save_secret_key=True) - # drop private key for server - context.make_context_public() - # payload data always needs to be wrapped into an FLModel - model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL) + def read_data(self, file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + def start_fl_collect_max_idx(self): + self.logger.info("send initial message to all sites to start FL \n") msg_payload = { MIN_RESPONSES: self.min_clients, CURRENT_ROUND: 1, NUM_ROUNDS: self.num_rounds, START_ROUND: 1, - DATA: model, + DATA: {}, } results = self.flare_comm.broadcast_and_wait(msg_payload) - return context, results + return results def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): self.logger.info("aggregate max histogram index \n") @@ -99,9 +98,13 @@ def distribute_max_idx_collect_enc_stats(self, result: int): results = self.flare_comm.broadcast_and_wait(msg_payload) return results - def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]): + def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): self.logger.info("aggregate histogram within HE \n") + # Load HE context + he_context_serial = self.read_data(self.he_context_path) + he_context = ts.context_from(he_context_serial) + if not sag_result: raise RuntimeError("input is None or empty") diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_data.py b/examples/advanced/kaplan-meier-he/utils/prepare_data.py index 951a93213c..66684a1b4b 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_data.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_data.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py new file mode 100644 index 0000000000..ceedf4c9a4 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/utils/prepare_he_context.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import base64 +import os + +import tenseal as ts + + +def data_split_args_parser(): + parser = argparse.ArgumentParser(description="Generate HE context") + parser.add_argument("--scheme", type=str, default="BFV", help="HE scheme, default is BFV") + parser.add_argument("--poly_modulus_degree", type=int, default=4096, help="Poly modulus degree, default is 4096") + parser.add_argument("--out_path", type=str, help="Output root path for HE context files for client and server") + return parser + + +def write_data(file_name: str, data: bytes): + data = base64.b64encode(data) + with open(file_name, "wb") as f: + f.write(data) + + +def main(): + parser = data_split_args_parser() + args = parser.parse_args() + if args.scheme == "BFV": + scheme = ts.SCHEME_TYPE.BFV + # Generate HE context + context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree, plain_modulus=1032193) + elif args.scheme == "CKKS": + scheme = ts.SCHEME_TYPE.CKKS + # Generate HE context, CKKS does not need plain_modulus + context = ts.context(scheme, poly_modulus_degree=args.poly_modulus_degree) + else: + raise ValueError("HE scheme not supported") + + # Save HE context to file for client + if not os.path.exists(args.out_path): + os.makedirs(args.out_path, exist_ok=True) + context_serial = context.serialize(save_secret_key=True) + write_data(os.path.join(args.out_path, "he_context_client.txt"), context_serial) + + # Save HE context to file for server + context_serial = context.serialize(save_secret_key=False) + write_data(os.path.join(args.out_path, "he_context_server.txt"), context_serial) + + +if __name__ == "__main__": + main() From 8313738d05a246c41363c1b40c67ad7e6b5844c4 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 1 Apr 2024 12:31:06 -0700 Subject: [PATCH 11/20] update to use new controller interface --- .../app/config/config_fed_server.conf | 18 ++--- .../app/custom/kaplan_meier_wf.py | 67 +++++-------------- .../advanced/kaplan-meier-he/requirements.txt | 2 +- .../app_common/workflows/model_controller.py | 7 +- 4 files changed, 31 insertions(+), 63 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf index 2ade684715..2589c856bd 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf @@ -5,18 +5,14 @@ task_result_filters = [] workflows = [ - { - id = "km" - path = "nvflare.app_common.workflows.wf_controller.WFController" - args { - task_name = "train" - wf_class_path = "kaplan_meier_wf.KM", - wf_args { - min_clients = 5 - he_context_path = "/tmp/flare/he_context/he_context_server.txt" - } - } + { + id = "km" + path = "kaplan_meier_wf.KM" + args { + min_clients = 5 + he_context_path = "/tmp/flare/he_context/he_context_server.txt" } + } ] components = [] diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index 1ec5e372e3..3b0170cb65 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -19,19 +19,12 @@ import tenseal as ts from nvflare.app_common.abstract.fl_model import FLModel, ParamsType -from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import ( - CURRENT_ROUND, - DATA, - MIN_RESPONSES, - NUM_ROUNDS, - START_ROUND, -) -from nvflare.app_common.workflows.wf_comm.wf_spec import WF +from nvflare.app_common.workflows.wf_controller import WFController # Controller Workflow -class KM(WF): +class KM(WFController): def __init__(self, min_clients: int, he_context_path: str): super(KM, self).__init__() self.logger = logging.getLogger(self.__class__.__name__) @@ -53,15 +46,9 @@ def read_data(self, file_name: str): def start_fl_collect_max_idx(self): self.logger.info("send initial message to all sites to start FL \n") - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 1, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: {}, - } - - results = self.flare_comm.broadcast_and_wait(msg_payload) + model = FLModel(params={}, current_round=1, total_rounds=self.num_rounds) + + results = self.send_model(data=model) return results def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): @@ -70,13 +57,8 @@ def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): if not sag_result: raise RuntimeError("input is None or empty") - task_name, task_result = next(iter(sag_result.items())) - - if not task_result: - raise RuntimeError("task_result None or empty ") - max_idx_global = [] - for site, fl_model in task_result.items(): + for fl_model in sag_result: max_idx = fl_model.params["max_idx"] max_idx_global.append(max_idx) # actual time point as index, so plus 1 for storage @@ -85,17 +67,14 @@ def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): def distribute_max_idx_collect_enc_stats(self, result: int): self.logger.info("send global max_index to all sites \n") - model = FLModel(params={"max_idx_global": result}, params_type=ParamsType.FULL) - - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 2, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: model, - } + model = FLModel( + params={"max_idx_global": result}, + params_type=ParamsType.FULL, + current_round=2, + total_rounds=self.num_rounds, + ) - results = self.flare_comm.broadcast_and_wait(msg_payload) + results = self.send_model(data=model) return results def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): @@ -108,14 +87,10 @@ def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): if not sag_result: raise RuntimeError("input is None or empty") - task_name, task_result = next(iter(sag_result.items())) - - if not task_result: - raise RuntimeError("task_result None or empty ") - hist_obs_global = None hist_cen_global = None - for site, fl_model in task_result.items(): + for fl_model in sag_result: + site = fl_model.meta.get("client_name", None) hist_obs_he_serial = fl_model.params["hist_obs"] hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) hist_cen_he_serial = fl_model.params["hist_cen"] @@ -146,15 +121,9 @@ def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial) model = FLModel( params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, params_type=ParamsType.FULL, + current_round=3, + total_rounds=self.num_rounds, ) - msg_payload = { - MIN_RESPONSES: self.min_clients, - CURRENT_ROUND: 3, - NUM_ROUNDS: self.num_rounds, - START_ROUND: 1, - DATA: model, - } - - results = self.flare_comm.broadcast_and_wait(msg_payload) + results = self.send_model(data=model) return results diff --git a/examples/advanced/kaplan-meier-he/requirements.txt b/examples/advanced/kaplan-meier-he/requirements.txt index 6b15006556..e6d18ba9a3 100644 --- a/examples/advanced/kaplan-meier-he/requirements.txt +++ b/examples/advanced/kaplan-meier-he/requirements.txt @@ -1,3 +1,3 @@ lifelines tenseal -scikit-survival \ No newline at end of file +scikit-survival diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 7da0e1d8df..e7cfd3a7fd 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -150,8 +150,11 @@ def _build_shareable(self, data: FLModel = None) -> Shareable: data = self.model data_shareable: Shareable = FLModelUtils.to_shareable(data) - data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) - data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) + + if not data_shareable.get_header(AppConstants.CURRENT_ROUND): + data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) + if not data_shareable.get_header(AppConstants.NUM_ROUNDS): + data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round) return data_shareable From f67b6067ae8ba6c34bcb5d167facd66a54a568bb Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 8 Apr 2024 09:08:07 -0700 Subject: [PATCH 12/20] change to send_model_and_wait --- .../kaplan-meier-he/app/custom/kaplan_meier_wf.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index 3b0170cb65..a14619e47e 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -46,9 +46,14 @@ def read_data(self, file_name: str): def start_fl_collect_max_idx(self): self.logger.info("send initial message to all sites to start FL \n") - model = FLModel(params={}, current_round=1, total_rounds=self.num_rounds) + model = FLModel( + params={}, + start_round=1, + current_round=1, + total_rounds=self.num_rounds + ) - results = self.send_model(data=model) + results = self.send_model_and_wait(data=model) return results def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): @@ -70,11 +75,12 @@ def distribute_max_idx_collect_enc_stats(self, result: int): model = FLModel( params={"max_idx_global": result}, params_type=ParamsType.FULL, + start_round=1, current_round=2, total_rounds=self.num_rounds, ) - results = self.send_model(data=model) + results = self.send_model_and_wait(data=model) return results def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): @@ -121,9 +127,10 @@ def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial) model = FLModel( params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, params_type=ParamsType.FULL, + start_round=1, current_round=3, total_rounds=self.num_rounds, ) - results = self.send_model(data=model) + results = self.send_model_and_wait(data=model) return results From adb244950acc1b39b2b10afe2a547869866be6c6 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 8 Apr 2024 09:17:03 -0700 Subject: [PATCH 13/20] format --- .../jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py index a14619e47e..1c9fdecaee 100644 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py @@ -46,12 +46,7 @@ def read_data(self, file_name: str): def start_fl_collect_max_idx(self): self.logger.info("send initial message to all sites to start FL \n") - model = FLModel( - params={}, - start_round=1, - current_round=1, - total_rounds=self.num_rounds - ) + model = FLModel(params={}, start_round=1, current_round=1, total_rounds=self.num_rounds) results = self.send_model_and_wait(data=model) return results From cc0a56e2e54c10b65c6739735a4f5c52940c2ef8 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 8 Apr 2024 09:30:20 -0700 Subject: [PATCH 14/20] updated readme --- examples/advanced/kaplan-meier-he/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 58cec6ebe2..1c54fcee3e 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -21,13 +21,13 @@ With these two settings, the server will have no access to any knowledge regardi The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms. -## Simulated HE Analysis via FLARE Workflow Communicator API +## Simulated HE Analysis via FLARE Workflow Controller API -The Flare Workflow Communicator API provides the functionality of customized message payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme. +The Flare Workflow Controller API (`WFController`) provides the functionality of customized FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme. Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py). -With the Flare Workflow Communicator API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: +With the Flare Workflow Controller API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: 1. Server send the simple start message without any payload. 2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. From 6549ea503dbcba64d1531817386684124a326725 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Mon, 8 Apr 2024 14:18:03 -0700 Subject: [PATCH 15/20] fix merge conflict --- nvflare/app_common/workflows/model_controller.py | 8 +++----- nvflare/app_common/workflows/wf_controller.py | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 09f538e014..f891ef8ee0 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -90,11 +90,9 @@ def start_controller(self, fl_ctx: FLContext) -> None: def _build_shareable(self, data: FLModel = None) -> Shareable: data_shareable: Shareable = FLModelUtils.to_shareable(data) - if not data_shareable.get_header(AppConstants.CURRENT_ROUND): - data_shareable.set_header(AppConstants.CURRENT_ROUND, self._current_round) - if not data_shareable.get_header(AppConstants.NUM_ROUNDS): - data_shareable.set_header(AppConstants.NUM_ROUNDS, self._num_rounds) - data_shareable.add_cookie(AppConstants.CONTRIBUTION_ROUND, self._current_round) + data_shareable.add_cookie( + AppConstants.CONTRIBUTION_ROUND, data_shareable.get_header(AppConstants.CURRENT_ROUND) + ) return data_shareable diff --git a/nvflare/app_common/workflows/wf_controller.py b/nvflare/app_common/workflows/wf_controller.py index 4847e1a45c..668bd6e348 100644 --- a/nvflare/app_common/workflows/wf_controller.py +++ b/nvflare/app_common/workflows/wf_controller.py @@ -23,13 +23,13 @@ class WFController(ModelController, ABC): def __init__( self, *args, - persistor_id: str = "persistor", + persistor_id: str = "", **kwargs, ): """Workflow Controller API for FLModel-based ModelController. Args: - persistor_id (str, optional): ID of the persistor component. Defaults to "persistor". + persistor_id (str, optional): ID of the persistor component. Defaults to "". """ super().__init__(*args, persistor_id, **kwargs) From fd6270447b29aeb551d3bf30d5e102dcb195e6bc Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 9 Apr 2024 10:03:58 -0400 Subject: [PATCH 16/20] update readme --- examples/advanced/kaplan-meier-he/README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 1c54fcee3e..fa8a3035f0 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -2,7 +2,7 @@ This example illustrates two features: * How to perform Kaplan-Meirer survival analysis in federated setting securely via Homomorphic Encryption (HE). -* How to use the Flare Workflow Communicator API to contract a workflow to facilitate HE under simulator mode. +* How to use the Flare Workflow Controller API to contract a workflow to facilitate HE under simulator mode. ## Secure Multi-party Kaplan-Meier Analysis Kaplan-Meier survival analysis is a one-shot (non-iterative) analysis performed on a list of events and their corresponding time. In this example, we use [lifelines](https://zenodo.org/records/10456828) to perform this analysis. @@ -23,13 +23,15 @@ The final Kaplan-Meier survival analysis will be performed locally on the global ## Simulated HE Analysis via FLARE Workflow Controller API -The Flare Workflow Controller API (`WFController`) provides the functionality of customized FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme. +The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. -Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py). +Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing functionalities: +- different HE schemes (BFV) rather than CKKS +- different content at different rounds of federated learning, and only specific payload needs to be encrypted -With the Flare Workflow Controller API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds: +With the WFController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 3 rounds: 1. Server send the simple start message without any payload. -2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. +2. Clients collect the information of the local maximum bin number (for event time) and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information. From 479fb52421eb0e36aa6079d815b8c3e002e1fa18 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 9 Apr 2024 10:11:51 -0400 Subject: [PATCH 17/20] update readme --- examples/advanced/kaplan-meier-he/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index fa8a3035f0..ae93463dd3 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -25,7 +25,7 @@ The final Kaplan-Meier survival analysis will be performed locally on the global The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. -Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing functionalities: +Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for deep learning): - different HE schemes (BFV) rather than CKKS - different content at different rounds of federated learning, and only specific payload needs to be encrypted From f1df4833c93dbd85d8f34bfbdb344e4028f57923 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 9 Apr 2024 10:13:09 -0400 Subject: [PATCH 18/20] update readme --- examples/advanced/kaplan-meier-he/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index ae93463dd3..4c707c99e7 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -25,7 +25,7 @@ The final Kaplan-Meier survival analysis will be performed locally on the global The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. -Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for deep learning): +Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models). - different HE schemes (BFV) rather than CKKS - different content at different rounds of federated learning, and only specific payload needs to be encrypted From 096a2025df9f4de5ba9be2bc4d0c58325ad998a3 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 9 Apr 2024 10:17:51 -0400 Subject: [PATCH 19/20] update readme --- examples/advanced/kaplan-meier-he/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 4c707c99e7..1317bd9379 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -43,11 +43,11 @@ python baseline_kaplan_meier.py ``` By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory. -Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each clients: +Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each client: ```commandline python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data" ``` -Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context. +Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context. ```commandline python utils/prepare_he_context.py --out_path "/tmp/flare/he_context" ``` From fafdd4363581e7857eea5f95c5b01dcb940951e0 Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Tue, 9 Apr 2024 09:05:59 -0700 Subject: [PATCH 20/20] move to job template --- examples/advanced/kaplan-meier-he/README.md | 16 +++++++++++++++- .../kaplan_meier_he}/config_fed_client.conf | 0 .../kaplan_meier_he}/config_fed_server.conf | 0 .../job_templates/kaplan_meier_he/info.conf | 5 +++++ .../job_templates/kaplan_meier_he/info.md | 11 +++++++++++ .../job_templates/kaplan_meier_he/meta.conf | 8 ++++++++ .../jobs/kaplan-meier-he/meta.conf | 7 ------- .../app/custom => src}/kaplan_meier_train.py | 0 .../app/custom => src}/kaplan_meier_wf.py | 0 9 files changed, 39 insertions(+), 8 deletions(-) rename examples/advanced/kaplan-meier-he/{jobs/kaplan-meier-he/app/config => job_templates/kaplan_meier_he}/config_fed_client.conf (100%) rename examples/advanced/kaplan-meier-he/{jobs/kaplan-meier-he/app/config => job_templates/kaplan_meier_he}/config_fed_server.conf (100%) create mode 100644 examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.conf create mode 100644 examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md create mode 100644 examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/meta.conf delete mode 100644 examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf rename examples/advanced/kaplan-meier-he/{jobs/kaplan-meier-he/app/custom => src}/kaplan_meier_train.py (100%) rename examples/advanced/kaplan-meier-he/{jobs/kaplan-meier-he/app/custom => src}/kaplan_meier_wf.py (100%) diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md index 1317bd9379..bd138c7583 100644 --- a/examples/advanced/kaplan-meier-he/README.md +++ b/examples/advanced/kaplan-meier-he/README.md @@ -1,7 +1,7 @@ # Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption This example illustrates two features: -* How to perform Kaplan-Meirer survival analysis in federated setting securely via Homomorphic Encryption (HE). +* How to perform Kaplan-Meier survival analysis in federated setting securely via Homomorphic Encryption (HE). * How to use the Flare Workflow Controller API to contract a workflow to facilitate HE under simulator mode. ## Secure Multi-party Kaplan-Meier Analysis @@ -52,6 +52,20 @@ Then we prepare HE context for clients and server, note that this step is done b python utils/prepare_he_context.py --out_path "/tmp/flare/he_context" ``` +Next, we set the location of the job templates directory. +```commandline +nvflare config -jt ./job_templates +``` + +Then we can generate the job configuration from the `kaplan_meier_he` template: + +```commandline +N_CLIENTS=5 +nvflare job create -force -j "./jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \ +-f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \ +-f config_fed_server.conf min_clients=${N_CLIENTS} he_context_path="/tmp/flare/he_context/he_context_server.txt" +``` + And we can run the federated job: ```commandline nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf similarity index 100% rename from examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_client.conf rename to examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf similarity index 100% rename from examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/config/config_fed_server.conf rename to examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.conf new file mode 100644 index 0000000000..a3579091d0 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.conf @@ -0,0 +1,5 @@ +{ + description = "Kaplan-Meier survival analysis with homomorphic encryption" + execution_api_type = "client_api" + controller_type = "server" +} \ No newline at end of file diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md new file mode 100644 index 0000000000..4d74281bf3 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## kaplan_meier_he + name = "kaplan_meier_he" + description = "Kaplan-Meier survival analysis with homomorphic encryption" + class_name = "KM" + controller_type = "server" + executor_type = "launcher_executor" + contributor = "NVIDIA" + init_publish_date = "2024-04-09" + last_updated_date = "2024-04-09" diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/meta.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/meta.conf new file mode 100644 index 0000000000..624acb062d --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/meta.conf @@ -0,0 +1,8 @@ +name = "kaplan_meier_he" +resource_spec {} +min_clients = 2 +deploy_map { + app = [ + "@ALL" + ] +} diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf b/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf deleted file mode 100644 index a393deb907..0000000000 --- a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/meta.conf +++ /dev/null @@ -1,7 +0,0 @@ -{ - name = "kaplan-meier-he" - deploy_map { - app = ["@ALL"] - } - min_clients = 2 -} diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py similarity index 100% rename from examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_train.py rename to examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py diff --git a/examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py similarity index 100% rename from examples/advanced/kaplan-meier-he/jobs/kaplan-meier-he/app/custom/kaplan_meier_wf.py rename to examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py