Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update KM example, add 2-stage solution without HE #2541

Merged
merged 7 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions examples/advanced/kaplan-meier-he/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption
# Secure Federated Kaplan-Meier Analysis via Time-Binning and Homomorphic Encryption

This example illustrates two features:
* How to perform Kaplan-Meier survival analysis in federated setting securely via Homomorphic Encryption (HE).
* How to perform Kaplan-Meier survival analysis in federated setting without and with secure features via time-binning and Homomorphic Encryption (HE).
* How to use the Flare Workflow Controller API to contract a workflow to facilitate HE under simulator mode.

## Secure Multi-party Kaplan-Meier Analysis
Expand All @@ -11,42 +11,59 @@ Essentially, the estimator needs to get access to the event list, and under the

However, this poses a data security concern - by sharing the event list, the raw data can be exposed to external parties, which break the core value of federated analysis.

Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing any raw information from a certain participant (at server end). This is achieved by two techniques:
Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing the raw information from a participant, the targeted protection includes:
- Prevent clients from getting RAW data from each other;
- Prevent the aggregation server to access ANY information from submissions.

- Condense the raw event list to two histograms (one for observed events and the other for censored event) binned at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms.
- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE.
This is achieved by two techniques:
- Condense the raw event list to two histograms (one for observed events and the other for censored event) using binning at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms. Note that coarser binning will lead to higher protection, but also lower resolution of the final Kaplan-Meier curve.
- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE. This will not cause any information loss, while the server will perform aggregation within encryption space.

With these two settings, the server will have no access to any knowledge regarding local submissions, and participants will only receive global aggregated histograms that will not contain distinguishable information regarding any individual participants (client number >= 3 - if only two participants, one can infer the other party's info by subtracting its own histograms).

The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms.
The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from decrypted global histograms.

## Baseline Kaplan-Meier Analysis
We first illustrate the baseline centralized Kaplan-Meier analysis without any secure features. We used veterans_lung_cancer dataset by
`from sksurv.datasets import load_veterans_lung_cancer`, and used `Status` as the event type and `Survival_in_days` as the event time to construct the event list.

To run the baseline script, simply execute:
```commandline
python baseline_kaplan_meier.py
```
By default, this will generate a KM curve image `km_curve_baseline.png` under `/tmp` directory. The resutling KM curve is shown below:
![KM survival baseline](figs/km_curve_baseline.png)
Here, we show the survival curve for both daily (without binning) and weekly binning. The two curves aligns well with each other, while the weekly-binned curve has lower resolution.

## Simulated HE Analysis via FLARE Workflow Controller API

## Federated Kaplan-Meier Analysis w/o and w/ HE
We make use of FLARE Workflow Controller API to implement the federated Kaplan-Meier analysis, both without and with HE.

The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning.

Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models).
- different HE schemes (BFV) rather than CKKS
- different content at different rounds of federated learning, and only specific payload needs to be encrypted

With the WFController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 3 rounds:
With the WFController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 2 rounds without HE, or 3 rounds with HE.

For the federated analysis without HE, the detailed steps are as follows:
1. Server sends the simple start message without any payload.
2. Clients submit the local event histograms to server. Server aggregates the histograms with varying lengths by adding event counts of the same slot together, and sends the aggregated histograms back to clients.

For the federated analysis with HE, we need to ensure proper HE aggregation using BFV, and the detailed steps are as follows:
1. Server send the simple start message without any payload.
2. Clients collect the information of the local maximum bin number (for event time) and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable.
3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients.

After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.
After these rounds, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.

## Run the job
We first run a baseline analysis with full event information:
First, we prepared data for a 5-client federated job. We split and generate the data files for each client with binning interval of 7 days.
```commandline
python baseline_kaplan_meier.py
python utils/prepare_data.py --site_num 5 --bin_days 7 --out_path "/tmp/flare/dataset/km_data"
```
By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory.

Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each client:
```commandline
python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data"
```
Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context.
```commandline
python utils/prepare_he_context.py --out_path "/tmp/flare/he_context"
Expand All @@ -57,23 +74,34 @@ Next, we set the location of the job templates directory.
nvflare config -jt ./job_templates
```

Then we can generate the job configuration from the `kaplan_meier_he` template:
Then we can generate the job configurations from the `kaplan_meier` template:

ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
Both for the federated job without HE:
```commandline
N_CLIENTS=5
nvflare job create -force -j "/tmp/flare/jobs/kaplan-meier" -w "kaplan_meier" -sd "./src" \
-f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data" \
-f config_fed_server.conf min_clients=${N_CLIENTS}
```
and for the federated job with HE:
```commandline
N_CLIENTS=5
nvflare job create -force -j "./jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \
-f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \
nvflare job create -force -j "/tmp/flare/jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \
-f config_fed_client.conf app_script="kaplan_meier_train_he.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \
-f config_fed_server.conf min_clients=${N_CLIENTS} he_context_path="/tmp/flare/he_context/he_context_server.txt"
```

And we can run the federated job:
```commandline
nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he
nvflare simulator -w /tmp/flare/workspace_km -n 5 -t 5 /tmp/flare/jobs/kaplan-meier
```
```commandline
nvflare simulator -w /tmp/flare/workspace_km_he -n 5 -t 5 /tmp/flare/jobs/kaplan-meier-he
```
By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory.
By default, this will generate a KM curve image `km_curve_fl.png` and `km_curve_fl_he.png` under each client's directory.

## Display Result

By comparing the two curves, we can observe that the two are identical:
![KM survival baseline](figs/km_curve_baseline.png)
By comparing the two curves, we can observe that all curves are identical:
![KM survival fl](figs/km_curve_fl.png)
![KM survival fl_he](figs/km_curve_fl_he.png)
30 changes: 20 additions & 10 deletions examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,7 +25,7 @@ def args_parser():
parser.add_argument(
"--output_curve_path",
type=str,
default="./km_curve_baseline.png",
default="/tmp/km_curve_baseline.png",
help="save path for the output curve",
)
return parser
Expand All @@ -34,11 +34,10 @@ def args_parser():
def prepare_data(bin_days: int = 7):
data_x, data_y = load_veterans_lung_cancer()
total_data_num = data_x.shape[0]
print(f"Total data count: {total_data_num}")
event = data_y["Status"]
time = data_y["Survival_in_days"]
# Categorize data to a bin, default is a week (7 days)
time = np.ceil(time / bin_days).astype(int)
time = np.ceil(time / bin_days).astype(int) * bin_days
return event, time


Expand All @@ -49,22 +48,33 @@ def main():
# Set parameters
output_curve_path = args.output_curve_path

# Generate data
event, time = prepare_data()
# Set plot
plt.figure()
plt.title("Baseline")

# Fit and plot Kaplan Meier curve with lifelines

# Generate data with binning
event, time = prepare_data(bin_days=7)
kmf = KaplanMeierFitter()
# Fit the survival data
kmf.fit(time, event)
# Plot and save the Kaplan-Meier survival curve
plt.figure()
plt.title("Baseline")
kmf.plot_survival_function()
kmf.plot_survival_function(label="Binned Weekly")

# Generate data without binning
event, time = prepare_data(bin_days=1)
kmf = KaplanMeierFitter()
# Fit the survival data
kmf.fit(time, event)
# Plot and save the Kaplan-Meier survival curve
kmf.plot_survival_function(label="No binning - Daily")

plt.ylim(0, 1)
plt.ylabel("prob")
plt.xlabel("time")
plt.legend("", frameon=False)
plt.tight_layout()
plt.legend()
plt.savefig(output_curve_path)


Expand Down
Binary file modified examples/advanced/kaplan-meier-he/figs/km_curve_baseline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/advanced/kaplan-meier-he/figs/km_curve_fl.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
{
# version of the configuration
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "kaplan_meier_train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

# This is an executor for Client API. The underline data exchange is using Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor"

args {
# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "raw"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "FULL"

# if train_with_evaluation is true, the executor will expect
# the custom code need to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = false
}
}
}
],

# this defined an array of task data filters. If provided, it will control the data from server controller to client executor
task_data_filters = []

# this defined an array of task result filters. If provided, it will control the result from client executor to server controller
task_result_filters = []

components = [
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# how fast should it read from the peer
read_interval = 0.1
}
},
{
# we use this component so the client api `flare.init()` can get required information
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = ["metric_relay"]
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
# version of the configuration
format_version = 2
task_data_filters =[]
task_result_filters = []

workflows = [
{
id = "km"
path = "kaplan_meier_wf.KM"
args {
min_clients = 5
}
}
]

components = []

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
description = "Kaplan-Meier survival analysis"
execution_api_type = "client_api"
controller_type = "server"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Job Template Information Card

## kaplan_meier
name = "kaplan_meier"
description = "Kaplan-Meier survival analysis"
class_name = "KM"
controller_type = "server"
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2024-04-09"
last_updated_date = "2024-04-30"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name = "kaplan_meier"
resource_spec {}
min_clients = 2
deploy_map {
app = [
"@ALL"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "kaplan_meier_train.py"
app_script = "kaplan_meier_train_he.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt"
app_config = ""

# Client Computing Executors.
executors = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
workflows = [
{
id = "km"
path = "kaplan_meier_wf.KM"
path = "kaplan_meier_wf_he.KM"
args {
min_clients = 5
he_context_path = "/tmp/flare/he_context/he_context_server.txt"
min_clients = 5
he_context_path = "/tmp/flare/he_context/he_context_server.txt"
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
executor_type = "launcher_executor"
contributor = "NVIDIA"
init_publish_date = "2024-04-09"
last_updated_date = "2024-04-09"
last_updated_date = "2024-04-30"
Loading
Loading