add readme with details
ZiyueXu77 committed Jan 5, 2024
1 parent 6031902 commit 07a6b32
Showing 5 changed files with 510 additions and 0 deletions.
@@ -0,0 +1,116 @@
{
  # version of the configuration
  format_version = 2

  # This is the application script that will be invoked. Clients can replace this script with their own training script.
  app_script = "km_train.py"

  # Additional arguments needed by the training code. For example, in Lightning, these can be --trainer.batch_size=xxx.
  app_config = ""

  # Client Computing Executors.
  executors = [
    {
      # tasks the executors are defined to handle
      tasks = ["train"]

      # This particular executor
      executor {

        # This is an executor for the Client API. The underlying data exchange uses a Pipe.
        path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor"

        args {
          # launcher_id is used to locate the Launcher object in "components"
          launcher_id = "launcher"

          # pipe_id is used to locate the Pipe object in "components"
          pipe_id = "pipe"

          # Timeout in seconds when waiting for a heartbeat from the training script; the default is 30 seconds.
          # Please refer to the class docstring for all available arguments.
          heartbeat_timeout = 60

          # format of the exchanged parameters
          params_exchange_format = "raw"

          # if params_transfer_type is FULL, the parameters are sent as-is;
          # if it is DIFF, the difference versus the received parameters is
          # computed and sent instead
          params_transfer_type = "FULL"

          # if train_with_evaluation is true, the executor expects the custom
          # code to send back both the trained parameters and the evaluation
          # metric; otherwise only trained parameters are expected
          train_with_evaluation = false
        }
      }
    }
  ],

  # this defines an array of task data filters. If provided, they will control the data sent from the server controller to the client executor
  task_data_filters = []

  # this defines an array of task result filters. If provided, they will control the results sent from the client executor to the server controller
  task_result_filters = []

  components = [
    {
      # component id is "launcher"
      id = "launcher"

      # the class path of this component
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

      args {
        # the launcher will invoke the script
        script = "python3 custom/{app_script} {app_config} "
        # if launch_once is true, the SubprocessLauncher launches one process for the whole job
        # if launch_once is false, the SubprocessLauncher launches a process for each task it receives from the server
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      id = "metrics_pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    },
    {
      id = "metric_relay"
      path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
      args {
        pipe_id = "metrics_pipe"
        event_type = "fed.analytix_log_stats"
        # interval in seconds between reads from the peer
        read_interval = 0.1
      }
    },
    {
      # we use this component so the Client API call `flare.init()` can get the required information
      id = "config_preparer"
      path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
      args {
        component_ids = ["metric_relay"]
      }
    }
  ]
}
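
For context, the app_script above points to km_train.py, which is not shown in this excerpt. A training script driven by this executor would follow the NVFlare Client API pattern sketched below; this is an illustration rather than the actual script from this commit, and the per-round payload keys ("max_idx", "hist_obs", "hist_cen") mirror the controller workflow later in this diff.

# Minimal sketch of a Client API script for this job (illustrative only; the
# real km_train.py in this commit may differ).
import nvflare.client as flare
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType

flare.init()  # picks up the info prepared by the "config_preparer" component

while flare.is_running():
    model = flare.receive()  # blocks until the server broadcasts the next task
    if model.current_round == 1:
        # round 1: receive the HE context; reply with the local max time index
        params = {"max_idx": 5}  # hypothetical local value
    elif model.current_round == 2:
        # round 2: receive the global max index; reply with encrypted local
        # histograms (serialized BFV vectors, see the workflow code below)
        params = {"hist_obs": b"...", "hist_cen": b"..."}  # placeholders
    else:
        # round 3: receive the encrypted global histograms, decrypt locally and
        # run the Kaplan-Meier analysis; nothing substantive to send back
        params = {}
    flare.send(FLModel(params=params, params_type=ParamsType.FULL))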
@@ -0,0 +1,23 @@
{
  # version of the configuration
  format_version = 2
  task_data_filters = []
  task_result_filters = []

  workflows = [
    {
      id = "km"
      path = "nvflare.app_common.workflows.wf_controller.WFController"
      args {
        task_name = "train"
        wf_class_path = "kaplan_meier_wf.KM",
        wf_args {
          min_clients = 2
        }
      }
    }
  ]

  components = []

}
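
Conceptually, the WFController in this config loads the class named by wf_class_path and constructs it with wf_args before driving it; for this job that amounts to something like the following (illustrative only; the actual wiring happens inside WFController):

# illustrative equivalent of what this server config asks WFController to do
from kaplan_meier_wf import KM

workflow = KM(min_clients=2)  # wf_args become constructor keyword arguments
workflow.run()                # executes the three rounds defined in KM below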
@@ -0,0 +1,158 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict

import tenseal as ts

from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.workflows.wf_comm.wf_comm_api_spec import (
    CURRENT_ROUND,
    DATA,
    MIN_RESPONSES,
    NUM_ROUNDS,
    START_ROUND,
)
from nvflare.app_common.workflows.wf_comm.wf_spec import WF

# Controller Workflow


class KM(WF):
    def __init__(self, min_clients: int):
        super(KM, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
        self.min_clients = min_clients
        self.num_rounds = 3

    def run(self):
        he_context, max_idx_results = self.distribute_he_context_collect_max_idx()
        global_res = self.aggr_max_idx(max_idx_results)
        enc_hist_results = self.distribute_max_idx_collect_enc_stats(global_res)
        hist_obs_global, hist_cen_global = self.aggr_he_hist(he_context, enc_hist_results)
        _ = self.distribute_global_hist(hist_obs_global, hist_cen_global)

    def distribute_he_context_collect_max_idx(self):
        self.logger.info("send Kaplan-Meier analysis command to all sites with HE context \n")

        # BFV supports exact integer arithmetic, which suits histogram counts;
        # plain_modulus = 1032193 is congruent to 1 mod 2 * 4096, as required for batching
        context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)
        context_serial = context.serialize(save_secret_key=True)
        # drop the private key from the server-side copy; the serialized copy
        # sent to the clients keeps it
        context.make_context_public()
        # payload data always needs to be wrapped into an FLModel
        model = FLModel(params={"he_context": context_serial}, params_type=ParamsType.FULL)

        msg_payload = {
            MIN_RESPONSES: self.min_clients,
            CURRENT_ROUND: 1,
            NUM_ROUNDS: self.num_rounds,
            START_ROUND: 1,
            DATA: model,
        }

        results = self.flare_comm.broadcast_and_wait(msg_payload)
        return context, results
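
    # Note on the structure of the results returned by broadcast_and_wait, as
    # consumed by the aggregation methods below: a dict keyed by task name whose
    # value maps each site name to the FLModel that site returned, i.e.
    # {task_name: {site_name: FLModel}} (inferred from how the results are
    # unpacked below, not from the wf_comm documentation).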

    def aggr_max_idx(self, sag_result: Dict[str, Dict[str, FLModel]]):
        self.logger.info("aggregate max histogram index \n")

        if not sag_result:
            raise RuntimeError("input is None or empty")

        task_name, task_result = next(iter(sag_result.items()))

        if not task_result:
            raise RuntimeError("task_result is None or empty")

        max_idx_global = []
        for site, fl_model in task_result.items():
            max_idx = fl_model.params["max_idx"]
            print(max_idx)
            max_idx_global.append(max_idx)
        # the actual time point is used as the index, so add 1 for storage
        return max(max_idx_global) + 1
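
    # Worked example with hypothetical values: if site-1 reports max_idx = 5 and
    # site-2 reports max_idx = 7, aggr_max_idx returns 7 + 1 = 8, so both sites
    # later build length-8 histograms and the encrypted vectors stay aligned
    # element-wise for the additions in aggr_he_hist.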

    def distribute_max_idx_collect_enc_stats(self, result: int):
        self.logger.info("send global max_index to all sites \n")

        model = FLModel(params={"max_idx_global": result}, params_type=ParamsType.FULL)

        msg_payload = {
            MIN_RESPONSES: self.min_clients,
            CURRENT_ROUND: 2,
            NUM_ROUNDS: self.num_rounds,
            START_ROUND: 1,
            DATA: model,
        }

        results = self.flare_comm.broadcast_and_wait(msg_payload)
        return results

    def aggr_he_hist(self, he_context, sag_result: Dict[str, Dict[str, FLModel]]):
        self.logger.info("aggregate histogram within HE \n")

        if not sag_result:
            raise RuntimeError("input is None or empty")

        task_name, task_result = next(iter(sag_result.items()))

        if not task_result:
            raise RuntimeError("task_result is None or empty")

        hist_obs_global = None
        hist_cen_global = None
        for site, fl_model in task_result.items():
            hist_obs_he_serial = fl_model.params["hist_obs"]
            hist_obs_he = ts.bfv_vector_from(he_context, hist_obs_he_serial)
            hist_cen_he_serial = fl_model.params["hist_cen"]
            hist_cen_he = ts.bfv_vector_from(he_context, hist_cen_he_serial)

            # accumulate the encrypted vectors; the additions happen under HE,
            # so the server never sees the plaintext histograms
            if hist_obs_global is None:
                print(f"assign global obs hist with result from {site}")
                hist_obs_global = hist_obs_he
            else:
                print(f"add to global obs hist with result from {site}")
                hist_obs_global += hist_obs_he

            if hist_cen_global is None:
                print(f"assign global cen hist with result from {site}")
                hist_cen_global = hist_cen_he
            else:
                print(f"add to global cen hist with result from {site}")
                hist_cen_global += hist_cen_he

        # return the two accumulated vectors, serialized for transmission
        hist_obs_global_serial = hist_obs_global.serialize()
        hist_cen_global_serial = hist_cen_global.serialize()
        return hist_obs_global_serial, hist_cen_global_serial

    def distribute_global_hist(self, hist_obs_global_serial, hist_cen_global_serial):
        self.logger.info("send globally accumulated histograms within HE to all sites \n")

        model = FLModel(
            params={"hist_obs_global": hist_obs_global_serial, "hist_cen_global": hist_cen_global_serial},
            params_type=ParamsType.FULL,
        )

        msg_payload = {
            MIN_RESPONSES: self.min_clients,
            CURRENT_ROUND: 3,
            NUM_ROUNDS: self.num_rounds,
            START_ROUND: 1,
            DATA: model,
        }

        results = self.flare_comm.broadcast_and_wait(msg_payload)
        return results
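
On the client side, which is not part of this excerpt, each site would deserialize the received HE context, encrypt its local histograms, and later decrypt the aggregated result. A minimal TenSEAL sketch under those assumptions (the histogram values are made up for the demo):

import tenseal as ts

# Stand-in for the serialized context a client receives in round 1; created
# here the same way the controller above creates it, to keep the demo runnable.
server_context = ts.context(ts.SCHEME_TYPE.BFV, poly_modulus_degree=4096, plain_modulus=1032193)
he_context_serial = server_context.serialize(save_secret_key=True)

# round 1: reconstruct the HE context; it carries the secret key, so the
# client can both encrypt and decrypt
he_context = ts.context_from(he_context_serial)

# round 2: encrypt the local histograms as BFV integer vectors before sending
hist_obs = [0, 2, 1, 0, 3]  # hypothetical counts of observed events per time index
hist_cen = [1, 0, 0, 2, 0]  # hypothetical counts of censored subjects per time index
hist_obs_he_serial = ts.bfv_vector(he_context, hist_obs).serialize()
hist_cen_he_serial = ts.bfv_vector(he_context, hist_cen).serialize()

# round 3: decrypt the aggregated global histogram received from the server;
# the aggregate is faked here by adding the local vector to itself
hist_obs_global = ts.bfv_vector_from(he_context, hist_obs_he_serial) + ts.bfv_vector_from(
    he_context, hist_obs_he_serial
)
print(ts.bfv_vector_from(he_context, hist_obs_global.serialize()).decrypt())  # [0, 4, 2, 0, 6]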