diff --git a/examples/advanced/kaplan-meier-he/README.md b/examples/advanced/kaplan-meier-he/README.md
index bd138c7583..983c88e71e 100644
--- a/examples/advanced/kaplan-meier-he/README.md
+++ b/examples/advanced/kaplan-meier-he/README.md
@@ -1,7 +1,7 @@
-# Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption
+# Secure Federated Kaplan-Meier Analysis via Time-Binning and Homomorphic Encryption
 
 This example illustrates two features:
-* How to perform Kaplan-Meier survival analysis in federated setting securely via Homomorphic Encryption (HE).
+* How to perform Kaplan-Meier survival analysis in a federated setting, both without and with secure features provided by time-binning and Homomorphic Encryption (HE).
 * How to use the Flare Workflow Controller API to contract a workflow to facilitate HE under simulator mode.
 
 ## Secure Multi-party Kaplan-Meier Analysis
@@ -11,17 +11,33 @@ Essentially, the estimator needs to get access to the event list, and under the
 
 However, this poses a data security concern - by sharing the event list, the raw data can be exposed to external parties, which break the core value of federated analysis.
 
-Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing any raw information from a certain participant (at server end). This is achieved by two techniques:
+Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing the raw information of any participant. The targeted protections are:
+- Prevent clients from getting RAW data from each other;
+- Prevent the aggregation server from accessing ANY information contained in the submissions.
 
-- Condense the raw event list to two histograms (one for observed events and the other for censored event) binned at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms.
-- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE.
+This is achieved by two techniques:
+- Condense the raw event list to two histograms (one for observed events and the other for censored events) using binning at a certain interval (e.g. a week), such that events falling within the same bin from different participants are aggregated together and become indistinguishable in the final aggregated histograms. Note that coarser binning provides stronger protection, but also lowers the resolution of the final Kaplan-Meier curve.
+- Encrypt the local histograms as a single vector before sending them to the server, and perform the global aggregation at the server side entirely within the encrypted space using HE. This causes no information loss, while the server never gains access to the plaintext values.
 
 With these two settings, the server will have no access to any knowledge regarding local submissions, and participants will only receive global aggregated histograms that will not contain distinguishable information regarding any individual participants (client number >= 3 - if only two participants, one can infer the other party's info by subtracting its own histograms).
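To make the binning technique concrete, here is a minimal sketch (an editorial illustration, not part of the diff) of how a local event list is condensed into binned observed/censored histograms. The toy `time_local`/`event_local` arrays are hypothetical; the binning arithmetic and the `survival_table_from_events` call mirror the example's own client code:

```python
import numpy as np
from lifelines.utils import survival_table_from_events

# Hypothetical local event list: survival time in days; True = observed, False = censored
time_local = np.array([3, 10, 10, 24, 25])
event_local = np.array([True, True, False, True, False])

# Bin times to weekly resolution, as utils/prepare_data.py does
bin_days = 7
time_binned = np.ceil(time_local / bin_days).astype(int) * bin_days  # [7, 14, 14, 28, 28]

# Condense to per-bin observed/censored counts, as kaplan_meier_train.py does
event_table = survival_table_from_events(time_binned, event_local)
print(event_table[["observed", "censored"]])
```

Events from different sites that fall into the same weekly bin become indistinguishable once these per-bin counts are summed.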
-The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms.
+The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from the decrypted global histograms.
+
+## Baseline Kaplan-Meier Analysis
+We first illustrate the baseline centralized Kaplan-Meier analysis without any secure features. We use the veterans_lung_cancer dataset, loaded via
+`from sksurv.datasets import load_veterans_lung_cancer`, with `Status` as the event type and `Survival_in_days` as the event time to construct the event list.
+
+To run the baseline script, simply execute:
+```commandline
+python utils/baseline_kaplan_meier.py
+```
+By default, this will generate a KM curve image `km_curve_baseline.png` under the `/tmp` directory. The resulting KM curve is shown below:
+![KM survival baseline](figs/km_curve_baseline.png)
+Here, we show the survival curves for both daily (no binning) and weekly binning. The two curves align well with each other, while the weekly-binned curve has lower resolution.
 
-## Simulated HE Analysis via FLARE Workflow Controller API
+
+## Federated Kaplan-Meier Analysis w/o and w/ HE
+We make use of the FLARE Workflow Controller API to implement the federated Kaplan-Meier analysis, both without and with HE.
 
 The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning.
 
@@ -29,24 +45,25 @@ Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/
 - different HE schemes (BFV) rather than CKKS
 - different content at different rounds of federated learning, and only specific payload needs to be encrypted
 
-With the WFController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 3 rounds:
+With the WFController API, such a "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE.
+
+For the federated analysis without HE, the detailed steps are as follows:
+1. Server sends a simple start message without any payload.
+2. Clients submit their local event histograms to the server. The server aggregates histograms of varying lengths by adding the event counts of the same time slot together, and sends the aggregated histograms back to the clients.
+
+For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows:
 1. Server send the simple start message without any payload.
 2. Clients collect the information of the local maximum bin number (for event time) and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable.
 3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients.
-After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.
+After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.
 
 ## Run the job
-We first run a baseline analysis with full event information:
+First, we prepare the data for a 5-client federated job. We split and generate the data files for each client with a binning interval of 7 days:
 ```commandline
-python baseline_kaplan_meier.py
+python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/flare/dataset/km_data"
 ```
-By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory.
-Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each client:
-```commandline
-python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data"
-```
 Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context.
 ```commandline
 python utils/prepare_he_context.py --out_path "/tmp/flare/he_context"
 ```
@@ -57,23 +74,34 @@ Next, we set the location of the job templates directory.
 nvflare config -jt ./job_templates
 ```
 
-Then we can generate the job configuration from the `kaplan_meier_he` template:
+Then we can generate the job configurations from the `kaplan_meier` and `kaplan_meier_he` templates.
+For the federated job without HE:
+```commandline
+N_CLIENTS=5
+nvflare job create -force -j "/tmp/flare/jobs/kaplan-meier" -w "kaplan_meier" -sd "./src" \
+-f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data" \
+-f config_fed_server.conf min_clients=${N_CLIENTS}
+```
+and for the federated job with HE:
 ```commandline
 N_CLIENTS=5
-nvflare job create -force -j "./jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \
--f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \
+nvflare job create -force -j "/tmp/flare/jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \
+-f config_fed_client.conf app_script="kaplan_meier_train_he.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \
 -f config_fed_server.conf min_clients=${N_CLIENTS} he_context_path="/tmp/flare/he_context/he_context_server.txt"
 ```
 
 And we can run the federated job:
 ```commandline
-nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he
+nvflare simulator -w /tmp/flare/workspace_km -n 5 -t 5 /tmp/flare/jobs/kaplan-meier
 ```
-By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory.
+```commandline
+nvflare simulator -w /tmp/flare/workspace_km_he -n 5 -t 5 /tmp/flare/jobs/kaplan-meier-he
+```
+By default, this will generate KM curve images `km_curve_fl.png` and `km_curve_fl_he.png` under each client's directory.
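Before looking at the results, the core of the HE job can be sketched in a few lines of TenSEAL (an editorial illustration, not part of the diff). The BFV parameters below are hypothetical placeholders; in the example, the real context is created and distributed by `utils/prepare_he_context.py`, but the encrypt-add-decrypt flow matches what the clients and server do:

```python
import tenseal as ts

# Hypothetical BFV context; the example distributes a real one via utils/prepare_he_context.py
context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)

# Two local histograms, padded to the same global length so the encrypted vectors are addable
hist_site1 = [3, 0, 2, 1]
hist_site2 = [1, 4, 0, 2]

enc1 = ts.bfv_vector(context, hist_site1)  # clients encrypt integer counts (BFV is exact on integers)
enc2 = ts.bfv_vector(context, hist_site2)

enc_sum = enc1 + enc2      # server-side addition, performed entirely on ciphertexts
print(enc_sum.decrypt())   # client-side decryption of the global histogram -> [4, 4, 2, 3]
```

This is also why Round 2 of the HE pipeline synchronizes a global maximum bin number first: BFV vectors can only be added element-wise when they have the same length.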
 ## Display Result
-By comparing the two curves, we can observe that the two are identical:
-![KM survival baseline](figs/km_curve_baseline.png)
+By comparing with the baseline curve above, we can observe that all curves are identical:
 ![KM survival fl](figs/km_curve_fl.png)
+![KM survival fl_he](figs/km_curve_fl_he.png)
diff --git a/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png b/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png
index 34cdb9cabb..9ff1fcdb4c 100644
Binary files a/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png and b/examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png differ
diff --git a/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png b/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png
index a4765d5654..df082d406a 100644
Binary files a/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png and b/examples/advanced/kaplan-meier-he/figs/km_curve_fl.png differ
diff --git a/examples/advanced/kaplan-meier-he/figs/km_curve_fl_he.png b/examples/advanced/kaplan-meier-he/figs/km_curve_fl_he.png
new file mode 100644
index 0000000000..b1610c4183
Binary files /dev/null and b/examples/advanced/kaplan-meier-he/figs/km_curve_fl_he.png differ
diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_client.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_client.conf
new file mode 100644
index 0000000000..0bbb867b6e
--- /dev/null
+++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_client.conf
@@ -0,0 +1,116 @@
+{
+  # version of the configuration
+  format_version = 2
+
+  # This is the application script which will be invoked. Client can replace this script with user's own training script.
+  app_script = "kaplan_meier_train.py"
+
+  # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
+  app_config = ""
+
+  # Client Computing Executors.
+  executors = [
+    {
+      # tasks the executors are defined to handle
+      tasks = ["train"]
+
+      # This particular executor
+      executor {
+
+        # This is an executor for the Client API. The underlying data exchange is using Pipe.
+        path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor"
+
+        args {
+          # launcher_id is used to locate the Launcher object in "components"
+          launcher_id = "launcher"
+
+          # pipe_id is used to locate the Pipe object in "components"
+          pipe_id = "pipe"
+
+          # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
+          # Please refer to the class docstring for all available arguments
+          heartbeat_timeout = 60
+
+          # format of the exchange parameters
+          params_exchange_format = "raw"
+
+          # if the transfer_type is FULL, then it will be sent directly
+          # if the transfer_type is DIFF, then we will calculate the
+          # difference VS received parameters and send the difference
+          params_transfer_type = "FULL"
+
+          # if train_with_evaluation is true, the executor will expect
+          # the custom code to send back both the trained parameters and the evaluation metric
+          # otherwise only trained parameters are expected
+          train_with_evaluation = false
+        }
+      }
+    }
+  ],
+
+  # This defines an array of task data filters. If provided, it will control the data from server controller to client executor
+  task_data_filters = []
+
+  # This defines an array of task result filters.
If provided, it will control the result from client executor to server controller + task_result_filters = [] + + components = [ + { + # component id is "launcher" + id = "launcher" + + # the class path of this component + path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" + + args { + # the launcher will invoke the script + script = "python3 custom/{app_script} {app_config} " + # if launch_once is true, the SubprocessLauncher will launch once for the whole job + # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server + launch_once = true + } + } + { + id = "pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + } + { + id = "metrics_pipe" + path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" + args { + mode = "PASSIVE" + site_name = "{SITE_NAME}" + token = "{JOB_ID}" + root_url = "{ROOT_URL}" + secure_mode = "{SECURE_MODE}" + workspace_dir = "{WORKSPACE}" + } + }, + { + id = "metric_relay" + path = "nvflare.app_common.widgets.metric_relay.MetricRelay" + args { + pipe_id = "metrics_pipe" + event_type = "fed.analytix_log_stats" + # how fast should it read from the peer + read_interval = 0.1 + } + }, + { + # we use this component so the client api `flare.init()` can get required information + id = "config_preparer" + path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" + args { + component_ids = ["metric_relay"] + } + } + ] +} diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_server.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_server.conf new file mode 100644 index 0000000000..8af09b4d44 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/config_fed_server.conf @@ -0,0 +1,19 @@ +{ + # version of the configuration + format_version = 2 + task_data_filters =[] + task_result_filters = [] + + workflows = [ + { + id = "km" + path = "kaplan_meier_wf.KM" + args { + min_clients = 5 + } + } + ] + + components = [] + +} diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.conf new file mode 100644 index 0000000000..adb8758a7b --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.conf @@ -0,0 +1,5 @@ +{ + description = "Kaplan-Meier survival analysis" + execution_api_type = "client_api" + controller_type = "server" +} \ No newline at end of file diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.md b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.md new file mode 100644 index 0000000000..e101b39174 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## kaplan_meier + name = "kaplan_meier" + description = "Kaplan-Meier survival analysis" + class_name = "KM" + controller_type = "server" + executor_type = "launcher_executor" + contributor = "NVIDIA" + init_publish_date = "2024-04-09" + last_updated_date = "2024-04-30" diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/meta.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/meta.conf new file mode 100644 index 0000000000..f7a133d56d --- /dev/null +++ 
b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier/meta.conf @@ -0,0 +1,8 @@ +name = "kaplan_meier" +resource_spec {} +min_clients = 2 +deploy_map { + app = [ + "@ALL" + ] +} diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf index 0704590617..ae7e3b08ab 100644 --- a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf @@ -3,10 +3,10 @@ format_version = 2 # This is the application script which will be invoked. Client can replace this script with user's own training script. - app_script = "kaplan_meier_train.py" + app_script = "kaplan_meier_train_he.py" # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. - app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" + app_config = "" # Client Computing Executors. executors = [ diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf index 2589c856bd..5747e68636 100644 --- a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf @@ -7,10 +7,10 @@ workflows = [ { id = "km" - path = "kaplan_meier_wf.KM" + path = "kaplan_meier_wf_he.KM" args { - min_clients = 5 - he_context_path = "/tmp/flare/he_context/he_context_server.txt" + min_clients = 5 + he_context_path = "/tmp/flare/he_context/he_context_server.txt" } } ] diff --git a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md index 4d74281bf3..f1f42d03ba 100644 --- a/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md +++ b/examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md @@ -8,4 +8,4 @@ executor_type = "launcher_executor" contributor = "NVIDIA" init_publish_date = "2024-04-09" - last_updated_date = "2024-04-09" + last_updated_date = "2024-04-30" diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py index 401c26aaf1..d8d7e55d28 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py +++ b/examples/advanced/kaplan-meier-he/src/kaplan_meier_train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,14 +13,12 @@ # limitations under the License. 
import argparse -import base64 import json import os import matplotlib.pyplot as plt import numpy as np import pandas as pd -import tenseal as ts from lifelines import KaplanMeierFitter from lifelines.utils import survival_table_from_events @@ -30,12 +28,6 @@ # Client code -def read_data(file_name: str): - with open(file_name, "rb") as f: - data = f.read() - return base64.b64decode(data) - - def details_save(kmf): # Get the survival function at all observed time points survival_function_at_all_times = kmf.survival_function_ @@ -63,7 +55,7 @@ def details_save(kmf): def plot_and_save(kmf): # Plot and save the Kaplan-Meier survival curve plt.figure() - plt.title("Federated HE") + plt.title("Federated") kmf.plot_survival_function() plt.ylim(0, 1) plt.ylabel("prob") @@ -78,7 +70,6 @@ def plot_and_save(kmf): def main(): parser = argparse.ArgumentParser(description="KM analysis") parser.add_argument("--data_root", type=str, help="Root path for data files") - parser.add_argument("--he_context_path", type=str, help="Path for the HE context file") args = parser.parse_args() flare.init() @@ -92,11 +83,6 @@ def main(): event_local = data["event"] time_local = data["time"] - # HE context - # In real-life application, HE context is prepared by secure provisioning - he_context_serial = read_data(args.he_context_path) - he_context = ts.context_from(he_context_serial) - while flare.is_running(): # receives global message from NVFlare global_msg = flare.receive() @@ -105,63 +91,34 @@ def main(): if curr_round == 1: # First round: - # Empty payload from server, send max index back - # Condense local data to histogram + # Empty payload from server, send local histogram + # Convert local data to histogram event_table = survival_table_from_events(time_local, event_local) hist_idx = event_table.index.values.astype(int) - # Get the max index to be synced globally - max_hist_idx = max(hist_idx) - - # Send max to server - print(f"send max hist index for site = {flare.get_site_name()}") - model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) - flare.send(model) - - elif curr_round == 2: - # Second round, get global max index - # Organize local histogram and encrypt - max_idx_global = global_msg.params["max_idx_global"] - print("Global Max Idx") - print(max_idx_global) - # Convert local table to uniform histogram hist_obs = {} hist_cen = {} - for idx in range(max_idx_global): + for idx in range(max(hist_idx)): hist_obs[idx] = 0 hist_cen[idx] = 0 - # assign values + # Assign values idx = event_table.index.values.astype(int) observed = event_table["observed"].to_numpy() censored = event_table["censored"].to_numpy() for i in range(len(idx)): hist_obs[idx[i]] = observed[i] hist_cen[idx[i]] = censored[i] - # Encrypt with tenseal using BFV scheme since observations are integers - hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) - hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) - # Serialize for transmission - hist_obs_he_serial = hist_obs_he.serialize() - hist_cen_he_serial = hist_cen_he.serialize() - # Send encrypted histograms to server - response = FLModel( - params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL - ) + # Send histograms to server + response = FLModel(params={"hist_obs": hist_obs, "hist_cen": hist_cen}, params_type=ParamsType.FULL) flare.send(response) - elif curr_round == 3: + elif curr_round == 2: # Get global histograms - hist_obs_global_serial = global_msg.params["hist_obs_global"] - 
hist_cen_global_serial = global_msg.params["hist_cen_global"] - # Deserialize - hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) - hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) - # Decrypt - hist_obs_global = hist_obs_global.decrypt() - hist_cen_global = hist_cen_global.decrypt() + hist_obs_global = global_msg.params["hist_obs_global"] + hist_cen_global = global_msg.params["hist_cen_global"] # Unfold histogram to event list time_unfold = [] event_unfold = [] - for i in range(max_idx_global): + for i in hist_obs_global.keys(): for j in range(hist_obs_global[i]): time_unfold.append(i) event_unfold.append(True) diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py new file mode 100644 index 0000000000..1ff9c69dbb --- /dev/null +++ b/examples/advanced/kaplan-meier-he/src/kaplan_meier_train_he.py @@ -0,0 +1,195 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import base64 +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import tenseal as ts +from lifelines import KaplanMeierFitter +from lifelines.utils import survival_table_from_events + +# (1) import nvflare client API +import nvflare.client as flare +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType + + +# Client code +def read_data(file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + + +def details_save(kmf): + # Get the survival function at all observed time points + survival_function_at_all_times = kmf.survival_function_ + # Get the timeline (time points) + timeline = survival_function_at_all_times.index.values + # Get the KM estimate + km_estimate = survival_function_at_all_times["KM_estimate"].values + # Get the event count at each time point + event_count = kmf.event_table.iloc[:, 0].values # Assuming the first column is the observed events + # Get the survival rate at each time point (using the 1st column of the survival function) + survival_rate = 1 - survival_function_at_all_times.iloc[:, 0].values + # Return the results + results = { + "timeline": timeline.tolist(), + "km_estimate": km_estimate.tolist(), + "event_count": event_count.tolist(), + "survival_rate": survival_rate.tolist(), + } + file_path = os.path.join(os.getcwd(), "km_global.json") + print(f"save the details of KM analysis result to {file_path} \n") + with open(file_path, "w") as json_file: + json.dump(results, json_file, indent=4) + + +def plot_and_save(kmf): + # Plot and save the Kaplan-Meier survival curve + plt.figure() + plt.title("Federated HE") + kmf.plot_survival_function() + plt.ylim(0, 1) + plt.ylabel("prob") + plt.xlabel("time") + plt.legend("", frameon=False) + plt.tight_layout() + file_path = os.path.join(os.getcwd(), "km_curve_fl_he.png") + print(f"save the curve plot to {file_path} \n") + plt.savefig(file_path) + + 
+def main(): + parser = argparse.ArgumentParser(description="KM analysis") + parser.add_argument("--data_root", type=str, help="Root path for data files") + parser.add_argument("--he_context_path", type=str, help="Path for the HE context file") + args = parser.parse_args() + + flare.init() + + site_name = flare.get_site_name() + print(f"Kaplan-meier analysis for {site_name}") + + # get local data + data_path = os.path.join(args.data_root, site_name + ".csv") + data = pd.read_csv(data_path) + event_local = data["event"] + time_local = data["time"] + + # HE context + # In real-life application, HE context is prepared by secure provisioning + he_context_serial = read_data(args.he_context_path) + he_context = ts.context_from(he_context_serial) + + while flare.is_running(): + # receives global message from NVFlare + global_msg = flare.receive() + curr_round = global_msg.current_round + print(f"current_round={curr_round}") + + if curr_round == 1: + # First round: + # Empty payload from server, send max index back + # Condense local data to histogram + event_table = survival_table_from_events(time_local, event_local) + hist_idx = event_table.index.values.astype(int) + # Get the max index to be synced globally + max_hist_idx = max(hist_idx) + + # Send max to server + print(f"send max hist index for site = {flare.get_site_name()}") + model = FLModel(params={"max_idx": max_hist_idx}, params_type=ParamsType.FULL) + flare.send(model) + + elif curr_round == 2: + # Second round, get global max index + # Organize local histogram and encrypt + max_idx_global = global_msg.params["max_idx_global"] + print("Global Max Idx") + print(max_idx_global) + # Convert local table to uniform histogram + hist_obs = {} + hist_cen = {} + for idx in range(max_idx_global): + hist_obs[idx] = 0 + hist_cen[idx] = 0 + # assign values + idx = event_table.index.values.astype(int) + observed = event_table["observed"].to_numpy() + censored = event_table["censored"].to_numpy() + for i in range(len(idx)): + hist_obs[idx[i]] = observed[i] + hist_cen[idx[i]] = censored[i] + # Encrypt with tenseal using BFV scheme since observations are integers + hist_obs_he = ts.bfv_vector(he_context, list(hist_obs.values())) + hist_cen_he = ts.bfv_vector(he_context, list(hist_cen.values())) + # Serialize for transmission + hist_obs_he_serial = hist_obs_he.serialize() + hist_cen_he_serial = hist_cen_he.serialize() + # Send encrypted histograms to server + response = FLModel( + params={"hist_obs": hist_obs_he_serial, "hist_cen": hist_cen_he_serial}, params_type=ParamsType.FULL + ) + flare.send(response) + + elif curr_round == 3: + # Get global histograms + hist_obs_global_serial = global_msg.params["hist_obs_global"] + hist_cen_global_serial = global_msg.params["hist_cen_global"] + # Deserialize + hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_global_serial) + hist_cen_global = ts.bfv_vector_from(he_context, hist_cen_global_serial) + # Decrypt + hist_obs_global = hist_obs_global.decrypt() + hist_cen_global = hist_cen_global.decrypt() + # Unfold histogram to event list + time_unfold = [] + event_unfold = [] + for i in range(max_idx_global): + for j in range(hist_obs_global[i]): + time_unfold.append(i) + event_unfold.append(True) + for k in range(hist_cen_global[i]): + time_unfold.append(i) + event_unfold.append(False) + time_unfold = np.array(time_unfold) + event_unfold = np.array(event_unfold) + + # Perform Kaplan-Meier analysis on global aggregated information + # Create a Kaplan-Meier estimator + kmf = KaplanMeierFitter() + + # Fit the 
model + kmf.fit(durations=time_unfold, event_observed=event_unfold) + + # Plot and save the KM curve + plot_and_save(kmf) + + # Save details of the KM result to a json file + details_save(kmf) + + # Send a simple response to server + response = FLModel(params={}, params_type=ParamsType.FULL) + flare.send(response) + + print(f"finish send for {site_name}, complete") + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py index 1c9fdecaee..c702b2dad0 100644 --- a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py +++ b/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,118 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 import logging from typing import Dict -import tenseal as ts - from nvflare.app_common.abstract.fl_model import FLModel, ParamsType from nvflare.app_common.workflows.wf_controller import WFController -# Controller Workflow - +# Controller Workflow class KM(WFController): - def __init__(self, min_clients: int, he_context_path: str): + def __init__(self, min_clients: int): super(KM, self).__init__() self.logger = logging.getLogger(self.__class__.__name__) self.min_clients = min_clients - self.he_context_path = he_context_path - self.num_rounds = 3 + self.num_rounds = 2 def run(self): - max_idx_results = self.start_fl_collect_max_idx() - global_res = self.aggr_max_idx(max_idx_results) - enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) - hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results) + hist_local = self.start_fl_collect_hist() + hist_obs_global, hist_cen_global = self.aggr_hist(hist_local) _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) - def read_data(self, file_name: str): - with open(file_name, "rb") as f: - data = f.read() - return base64.b64decode(data) - - def start_fl_collect_max_idx(self): + def start_fl_collect_hist(self): self.logger.info("send initial message to all sites to start FL \n") model = FLModel(params={}, start_round=1, current_round=1, total_rounds=self.num_rounds) results = self.send_model_and_wait(data=model) return results - def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): - self.logger.info("aggregate max histogram index \n") + def aggr_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate histogram \n") if not sag_result: raise RuntimeError("input is None or empty") - max_idx_global = [] + hist_idx_max = 0 for fl_model in sag_result: - max_idx = fl_model.params["max_idx"] - max_idx_global.append(max_idx) - # actual time point as index, so plus 1 for storage - return max(max_idx_global) + 1 + hist = fl_model.params["hist_obs"] + if hist_idx_max < max(hist.keys()): + hist_idx_max = max(hist.keys()) + hist_idx_max += 1 - def distribute_max_idx_collect_enc_stats(self, result: int): - self.logger.info("send global max_index to all sites \n") + hist_obs_global = {} + hist_cen_global = {} + for idx in range(hist_idx_max + 1): + hist_obs_global[idx] = 0 + hist_cen_global[idx] = 0 - model = FLModel( - params={"max_idx_global": result}, - params_type=ParamsType.FULL, 
- start_round=1, - current_round=2, - total_rounds=self.num_rounds, - ) - - results = self.send_model_and_wait(data=model) - return results - - def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): - self.logger.info("aggregate histogram within HE \n") - - # Load HE context - he_context_serial = self.read_data(self.he_context_path) - he_context = ts.context_from(he_context_serial) + for fl_model in sag_result: + hist_obs = fl_model.params["hist_obs"] + hist_cen = fl_model.params["hist_cen"] + for i in hist_obs.keys(): + hist_obs_global[i] += hist_obs[i] + for i in hist_cen.keys(): + hist_cen_global[i] += hist_cen[i] - if not sag_result: - raise RuntimeError("input is None or empty") + return hist_obs_global, hist_cen_global - hist_obs_global = None - hist_cen_global = None - for fl_model in sag_result: - site = fl_model.meta.get("client_name", None) - hist_obs_he_serial = fl_model.params["hist_obs"] - hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) - hist_cen_he_serial = fl_model.params["hist_cen"] - hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) - - if not hist_obs_global: - print(f"assign global hist with result from {site}") - hist_obs_global = hist_obs_he - else: - print(f"add to global hist with result from {site}") - hist_obs_global += hist_obs_he - - if not hist_cen_global: - print(f"assign global hist with result from {site}") - hist_cen_global = hist_cen_he - else: - print(f"add to global hist with result from {site}") - hist_cen_global += hist_cen_he - - # return the two accumulated vectors, serialized for transmission - hist_obs_global_serial = hist_obs_global.serialize() - hist_cen_global_serial = hist_cen_global.serialize() - return hist_obs_global_serial, hist_cen_global_serial - - def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): + def distribute_global_hist(self, hist_obs_global, hist_cen_global): self.logger.info("send global accumulated histograms within HE to all sites \n") model = FLModel( - params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, + params={"hist_obs_global": hist_obs_global, "hist_cen_global": hist_cen_global}, params_type=ParamsType.FULL, start_round=1, - current_round=3, + current_round=2, total_rounds=self.num_rounds, ) diff --git a/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py b/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py new file mode 100644 index 0000000000..5cd1a86012 --- /dev/null +++ b/examples/advanced/kaplan-meier-he/src/kaplan_meier_wf_he.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import base64 +import logging +from typing import Dict + +import tenseal as ts + +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.workflows.wf_controller import WFController + +# Controller Workflow + + +class KM(WFController): + def __init__(self, min_clients: int, he_context_path: str): + super(KM, self).__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.min_clients = min_clients + self.he_context_path = he_context_path + self.num_rounds = 3 + + def run(self): + max_idx_results = self.start_fl_collect_max_idx() + global_res = self.aggr_max_idx(max_idx_results) + enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res) + hist_obs_global, hist_cen_global = self.aggr_he_hist(enc_hist_results) + _ = self.distribute_global_hist(hist_obs_global, hist_cen_global) + + def read_data(self, file_name: str): + with open(file_name, "rb") as f: + data = f.read() + return base64.b64decode(data) + + def start_fl_collect_max_idx(self): + self.logger.info("send initial message to all sites to start FL \n") + model = FLModel(params={}, start_round=1, current_round=1, total_rounds=self.num_rounds) + + results = self.send_model_and_wait(data=model) + return results + + def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate max histogram index \n") + + if not sag_result: + raise RuntimeError("input is None or empty") + + max_idx_global = [] + for fl_model in sag_result: + max_idx = fl_model.params["max_idx"] + max_idx_global.append(max_idx) + # actual time point as index, so plus 1 for storage + return max(max_idx_global) + 1 + + def distribute_max_idx_collect_enc_stats(self, result: int): + self.logger.info("send global max_index to all sites \n") + + model = FLModel( + params={"max_idx_global": result}, + params_type=ParamsType.FULL, + start_round=1, + current_round=2, + total_rounds=self.num_rounds, + ) + + results = self.send_model_and_wait(data=model) + return results + + def aggr_he_hist(self, sag_result: Dict[str, Dict[str, FLModel]]): + self.logger.info("aggregate histogram within HE \n") + + # Load HE context + he_context_serial = self.read_data(self.he_context_path) + he_context = ts.context_from(he_context_serial) + + if not sag_result: + raise RuntimeError("input is None or empty") + + hist_obs_global = None + hist_cen_global = None + for fl_model in sag_result: + site = fl_model.meta.get("client_name", None) + hist_obs_he_serial = fl_model.params["hist_obs"] + hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial) + hist_cen_he_serial = fl_model.params["hist_cen"] + hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial) + + if not hist_obs_global: + print(f"assign global hist with result from {site}") + hist_obs_global = hist_obs_he + else: + print(f"add to global hist with result from {site}") + hist_obs_global += hist_obs_he + + if not hist_cen_global: + print(f"assign global hist with result from {site}") + hist_cen_global = hist_cen_he + else: + print(f"add to global hist with result from {site}") + hist_cen_global += hist_cen_he + + # return the two accumulated vectors, serialized for transmission + hist_obs_global_serial = hist_obs_global.serialize() + hist_cen_global_serial = hist_cen_global.serialize() + return hist_obs_global_serial, hist_cen_global_serial + + def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial): + self.logger.info("send global accumulated histograms within HE to all sites \n") + + model = 
FLModel( + params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial}, + params_type=ParamsType.FULL, + start_round=1, + current_round=3, + total_rounds=self.num_rounds, + ) + + results = self.send_model_and_wait(data=model) + return results diff --git a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py b/examples/advanced/kaplan-meier-he/utils/baseline_kaplan_meier.py similarity index 74% rename from examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py rename to examples/advanced/kaplan-meier-he/utils/baseline_kaplan_meier.py index 79e7b8052f..86292c2e51 100644 --- a/examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py +++ b/examples/advanced/kaplan-meier-he/utils/baseline_kaplan_meier.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ def args_parser(): parser.add_argument( "--output_curve_path", type=str, - default="./km_curve_baseline.png", + default="/tmp/km_curve_baseline.png", help="save path for the output curve", ) return parser @@ -34,11 +34,10 @@ def args_parser(): def prepare_data(bin_days: int = 7): data_x, data_y = load_veterans_lung_cancer() total_data_num = data_x.shape[0] - print(f"Total data count: {total_data_num}") event = data_y["Status"] time = data_y["Survival_in_days"] # Categorize data to a bin, default is a week (7 days) - time = np.ceil(time / bin_days).astype(int) + time = np.ceil(time / bin_days).astype(int) * bin_days return event, time @@ -49,22 +48,33 @@ def main(): # Set parameters output_curve_path = args.output_curve_path - # Generate data - event, time = prepare_data() + # Set plot + plt.figure() + plt.title("Baseline") # Fit and plot Kaplan Meier curve with lifelines + + # Generate data with binning + event, time = prepare_data(bin_days=7) kmf = KaplanMeierFitter() # Fit the survival data kmf.fit(time, event) # Plot and save the Kaplan-Meier survival curve - plt.figure() - plt.title("Baseline") - kmf.plot_survival_function() + kmf.plot_survival_function(label="Binned Weekly") + + # Generate data without binning + event, time = prepare_data(bin_days=1) + kmf = KaplanMeierFitter() + # Fit the survival data + kmf.fit(time, event) + # Plot and save the Kaplan-Meier survival curve + kmf.plot_survival_function(label="No binning - Daily") + plt.ylim(0, 1) plt.ylabel("prob") plt.xlabel("time") - plt.legend("", frameon=False) plt.tight_layout() + plt.legend() plt.savefig(output_curve_path) diff --git a/examples/advanced/kaplan-meier-he/utils/prepare_data.py b/examples/advanced/kaplan-meier-he/utils/prepare_data.py index 66684a1b4b..0517ad6274 100644 --- a/examples/advanced/kaplan-meier-he/utils/prepare_data.py +++ b/examples/advanced/kaplan-meier-he/utils/prepare_data.py @@ -31,11 +31,12 @@ def data_split_args_parser(): default="site-", help="Site name prefix, default is site-", ) + parser.add_argument("--bin_days", type=int, default=1, help="Bin days for categorizing data") parser.add_argument("--out_path", type=str, help="Output root path for split data files") return parser -def prepare_data(data, site_num, bin_days: int = 7): +def prepare_data(data, site_num, bin_days): # Get total data count total_data_num = data.shape[0] print(f"Total data count: {total_data_num}") @@ -43,7 +44,7 @@ def prepare_data(data, site_num, bin_days: int = 7): event = data["Status"] 
time = data["Survival_in_days"] # Categorize data to a bin, default is a week (7 days) - time = np.ceil(time / bin_days).astype(int) + time = np.ceil(time / bin_days).astype(int) * bin_days # Shuffle data idx = np.random.permutation(total_data_num) # Split data to clients @@ -68,7 +69,7 @@ def main(): _, data = load_veterans_lung_cancer() # Prepare data - event_clients, time_clients = prepare_data(data=data, site_num=args.site_num) + event_clients, time_clients = prepare_data(data=data, site_num=args.site_num, bin_days=args.bin_days) # Save data to csv files if not os.path.exists(args.out_path):
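As a worked example of the binning change above (illustrative values only): multiplying by `bin_days` after the ceiling keeps binned times on the original day scale, mapping each time to its bin's upper edge rather than to a unitless bin index, so the binned and unbinned KM curves share the same time axis:

```python
import numpy as np

bin_days = 7
time = np.array([3, 10, 14, 15])

print(np.ceil(time / bin_days).astype(int))             # old behavior, bin index: [1 2 2 3]
print(np.ceil(time / bin_days).astype(int) * bin_days)  # new behavior, days:      [ 7 14 14 21]
```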