Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example for mulitparty kaplan-meier analysis with HE #2259

Merged
merged 29 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6fb7a50
add example for mulitparty kaplan meier analysis with HE
ZiyueXu77 Jan 5, 2024
e54e6b5
update requirements
ZiyueXu77 Jan 5, 2024
87ecf77
update baseline script, remove complex settings and keep basic only
ZiyueXu77 Jan 5, 2024
6031902
add readme with details
ZiyueXu77 Jan 5, 2024
fb950ed
add readme with details
ZiyueXu77 Jan 5, 2024
c072d31
add curves, modify saving functions (curve and km details)
ZiyueXu77 Jan 5, 2024
2959b3f
job name update
ZiyueXu77 Jan 5, 2024
484cf3d
remove redundant print
ZiyueXu77 Jan 5, 2024
c927579
Merge branch 'main' into HE_KM
ZiyueXu77 Jan 8, 2024
f2e6d10
move data preparation part out of local code
ZiyueXu77 Jan 8, 2024
8737799
Merge branch 'main' into HE_KM
ZiyueXu77 Jan 8, 2024
8cb6d7d
move HE context part out of FL process to better accomodate the trans…
ZiyueXu77 Jan 10, 2024
3f5af59
Merge branch 'main' into HE_KM
ZiyueXu77 Jan 10, 2024
980fcbf
Merge branch 'NVIDIA:main' into HE_KM
ZiyueXu77 Jan 16, 2024
16fa0a0
Merge branch 'NVIDIA:main' into HE_KM
ZiyueXu77 Jan 29, 2024
10d82ab
Merge branch 'main' into HE_KM
SYangster Apr 2, 2024
3aa0919
Merge branch 'main' into HE_KM
SYangster Apr 4, 2024
8313738
update to use new controller interface
SYangster Apr 1, 2024
f67b606
change to send_model_and_wait
SYangster Apr 8, 2024
be90bdf
Merge branch 'main' into HE_KM
SYangster Apr 8, 2024
adb2449
format
SYangster Apr 8, 2024
cc0a56e
updated readme
SYangster Apr 8, 2024
be0b941
Merge branch 'main' into HE_KM
ZiyueXu77 Apr 8, 2024
6549ea5
fix merge conflict
SYangster Apr 8, 2024
fd62704
update readme
ZiyueXu77 Apr 9, 2024
479fb52
update readme
ZiyueXu77 Apr 9, 2024
f1df483
update readme
ZiyueXu77 Apr 9, 2024
096a202
update readme
ZiyueXu77 Apr 9, 2024
fafdd43
move to job template
SYangster Apr 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/advanced/kaplan-meier-he/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption

This example illustrates two features:
* How to perform Kaplan-Meirer survival analysis in federated setting securely via Homomorphic Encryption (HE).
* How to use the Flare Workflow Communicator API to contract a workflow to facilitate HE under simulator mode.
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved

## Secure Multi-party Kaplan-Meier Analysis
Kaplan-Meier survival analysis is a one-shot (non-iterative) analysis performed on a list of events and their corresponding time. In this example, we use [lifelines](https://zenodo.org/records/10456828) to perform this analysis.

Essentially, the estimator needs to get access to the event list, and under the setting of federated analysis, the aggregated event list from all participants.

However, this poses a data security concern - by sharing the event list, the raw data can be exposed to external parties, which break the core value of federated analysis.

Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing any raw information from a certain participant (at server end). This is achieved by two techniques:

- Condense the raw event list to two histograms (one for observed events and the other for censored event) binned at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms.
- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE.

With these two settings, the server will have no access to any knowledge regarding local submissions, and participants will only receive global aggregated histograms that will not contain distinguishable information regarding any individual participants (client number >= 3 - if only two participants, one can infer the other party's info by subtracting its own histograms).

The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms.


## Simulated HE Analysis via FLARE Workflow Communicator API

The Flare Workflow Communicator API provides the functionality of customized message payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme.

Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) does not support [simulator mode](https://nvflare.readthedocs.io/en/main/getting_started.html), the main reason is that the HE context information (specs and keys) needs to be provisioned before initializing the federated job. For the same reason, it is not straightforward for users to try different HE schemes beyond our existing support for [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py).

With the Flare Workflow Communicator API, such "proof of concept" experiment becomes easy (of course, secure provisioning is still the way to go for real-life federated applications). In this example, the federated analysis pipeline includes 3 rounds:
1. Server generate and distribute the HE context to clients, and remove the private key on server side. Again, this step is done by secure provisioning for real-life applications, but for simulator, we use this step to distribute the HE context.
2. Clients collect the information of the local maximum time bin number and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable.
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients.

After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information.

## Run the job
We first run a baseline analysis with full event information:
```commandline
python baseline_kaplan_meier.py
```
By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory.

Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each clients:
```commandline
python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data"
```
And we can run the federated job:
```commandline
nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he
```
By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory.

## Display Result

By comparing the two curves, we can observe that the two are identical:
![KM survival baseline](figs/km_curve_baseline.png)
![KM survival fl](figs/km_curve_fl.png)
72 changes: 72 additions & 0 deletions examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import matplotlib.pyplot as plt
import numpy as np
from lifelines import KaplanMeierFitter
from sksurv.datasets import load_veterans_lung_cancer


def args_parser():
parser = argparse.ArgumentParser(description="Kaplan Meier Survival Analysis Baseline")
parser.add_argument(
"--output_curve_path",
type=str,
default="./km_curve_baseline.png",
help="save path for the output curve",
)
return parser


def prepare_data(bin_days: int = 7):
data_x, data_y = load_veterans_lung_cancer()
total_data_num = data_x.shape[0]
print(f"Total data count: {total_data_num}")
event = data_y["Status"]
time = data_y["Survival_in_days"]
# Categorize data to a bin, default is a week (7 days)
time = np.ceil(time / bin_days).astype(int)
return event, time


def main():
parser = args_parser()
args = parser.parse_args()

# Set parameters
output_curve_path = args.output_curve_path

# Generate data
event, time = prepare_data()

# Fit and plot Kaplan Meier curve with lifelines
kmf = KaplanMeierFitter()
# Fit the survival data
kmf.fit(time, event)
# Plot and save the Kaplan-Meier survival curve
plt.figure()
plt.title("Baseline")
kmf.plot_survival_function()
plt.ylim(0, 1)
plt.ylabel("prob")
plt.xlabel("time")
plt.legend("", frameon=False)
plt.tight_layout()
plt.savefig(output_curve_path)


if __name__ == "__main__":
main()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
{
SYangster marked this conversation as resolved.
Show resolved Hide resolved
# version of the configuration
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "kaplan_meier_train.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = "--data_root /tmp/flare/dataset/km_data"

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

# This is an executor for Client API. The underline data exchange is using Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor"

args {
# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "raw"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "FULL"

# if train_with_evaluation is true, the executor will expect
# the custom code need to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = false
}
}
}
],

# this defined an array of task data filters. If provided, it will control the data from server controller to client executor
task_data_filters = []
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved

# this defined an array of task result filters. If provided, it will control the result from client executor to server controller
task_result_filters = []

components = [
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# how fast should it read from the peer
read_interval = 0.1
}
},
{
# we use this component so the client api `flare.init()` can get required information
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = ["metric_relay"]
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
# version of the configuration
format_version = 2
task_data_filters =[]
task_result_filters = []

workflows = [
{
id = "km"
path = "nvflare.app_common.workflows.wf_controller.WFController"
args {
task_name = "train"
wf_class_path = "kaplan_meier_wf.KM",
wf_args {
min_clients = 5
}
}
}
]

components = []

}
Loading