diff --git a/examples/advanced/lr-newton-raphson/README.md b/examples/advanced/lr-newton-raphson/README.md new file mode 100644 index 0000000000..4c55208b97 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/README.md @@ -0,0 +1,246 @@ +# Federated Logistic Regression with Second-Order Newton-Raphson optimization + +This example shows how to implement a federated binary +classification via logistic regression with second-order Newton-Raphson optimization. + +The [UCI Heart Disease +dataset](https://archive.ics.uci.edu/dataset/45/heart+disease) is +used in this example. Scripts are provided to download and process the +dataset as described +[here](https://github.com/owkin/FLamby/tree/main/flamby/datasets/fed_heart_disease). This +dataset contains samples from 4 sites, splitted into training and +testing sets as described below: +|site | sample split | +|-------------|---------------------------------------| +|Cleveland | train: 199 samples, test: 104 samples | +|Hungary | train: 172 samples, test: 89 samples | +|Switzerland | train: 30 samples, test: 16 samples | +|Long Beach V | train: 85 samples, test: 45 samples | + +The number of features in each sample is 13. + +## Introduction + +The [Newton-Raphson +optimization](https://en.wikipedia.org/wiki/Newton%27s_method) problem +can be described as follows. + +In a binary classification task with logistic regression, the +probability of a data sample $x$ classified as positive is formulated +as: +$$p(x) = \sigma(\beta \cdot x + \beta_{0})$$ +where $\sigma(.)$ denotes the sigmoid function. We can incorporate +$\beta_{0}$ and $\beta$ into a single parameter vector $\theta = +( \beta_{0}, \beta)$. Let $d$ be the number +of features for each data sample $x$ and let $N$ be the number of data +samples. We then have the matrix version of the above probability +equation: +$$p(X) = \sigma( X \theta )$$ +Here $X$ is the matrix of all samples, with shape $N \times (d+1)$, +having it's first column filled with value 1 to account for the +intercept $\theta_{0}$. + +The goal is to compute parameter vector $\theta$ that maximizes the +below likelihood function: +$$L_{\theta} = \prod_{i=1}^{N} p(x_i)^{y_i} (1 - p(x_i)^{1-y_i})$$ + +The Newton-Raphson method optimizes the likelihood function via +quadratic approximation. Omitting the maths, the theoretical update +formula for parameter vector $\theta$ is: +$$\theta^{n+1} = \theta^{n} - H_{\theta^{n}}^{-1} \nabla L_{\theta^{n}}$$ +where +$$\nabla L_{\theta^{n}} = X^{T}(y - p(X))$$ +is the gradient of the likelihood function, with $y$ being the vector +of ground truth for sample data matrix $X$, and +$$H_{\theta^{n}} = -X^{T} D X$$ +is the Hessian of the likelihood function, with $D$ a diagonal matrix +where diagonal value at $(i,i)$ is $D(i,i) = p(x_i) (1 - p(x_i))$. + +In federated Newton-Raphson optimization, each client will compute its +own gradient $\nabla L_{\theta^{n}}$ and Hessian $H_{\theta^{n}}$ +based on local training samples. A server will aggregate the gradients +and Hessians computed from all clients, and perform the update of +parameter $\theta$ based on the theoretical update formula described +above. + +## Implementation + +Using `nvflare`, The federated logistic regression with Newton-Raphson +optimization is implemented as follows. + +On the server side, all workflow logics are implemented in +class `FedAvgNewtonRaphson`, which can be found +[here](job/newton_raphson/app/custom/newton_raphson_workflow.py). The +`FedAvgNewtonRaphson` class inherits from the +[`BaseFedAvg`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/base_fedavg.py) +class, which itself inherits from the **Workflow Controller** +([`WFController`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/wf_controller.py)) +class. This is the preferrable approach to implement a custom +workflow, since `WFController` decouples communication logic from +actual workflow (training & validation) logic. The mandatory +method to override in `WFController` is the +[`run()`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/wf_controller.py#L37) +method, where the orchestration of server-side workflow actually +happens. The implementation of `run()` method in +[`FedAvgNewtonRaphson`](job/newton_raphson/app/custom/newton_raphson_workflow.py) +is similar to the classic +[`FedAvg`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/fedavg.py#L44): +- Initialize the global model, this is acheived through method `load_model()` + from base class + [`ModelController`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/model_controller.py#L292), + which relies on the + [`ModelPersistor`](https://nvflare.readthedocs.io/en/main/glossary.html#persistor). A + custom + [`NewtonRaphsonModelPersistor`](job/newton_raphson/app/custom/newton_raphson_persistor.py) + is implemented in this example, which is based on the + [`NPModelPersistor`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/np/np_model_persistor.py) + for numpy data, since the _model_ in the case of logistic regression + is just the parameter vector $\theta$ that can be represented by a + numpy array. Only the `__init__` method needs to be re-implemented + to provide a proper initialization for the global parameter vector + $\theta$. +- During each training round, the global model will be sent to the + list of participating clients to perform a training task. This is + done using the + [`send_model_and_wait()`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/workflows/wf_controller.py#L41) + method. Once + the clients finish their local training, results will be collected + and sent back to server as + [`FLModel`](https://nvflare.readthedocs.io/en/main/programming_guide/fl_model.html#flmodel)s. +- Results sent by clients contain their locally computed gradient and + Hessian. A [custom aggregation + function](job/newton_raphson/app/custom/newton_raphson_workflow.py) + is implemented to get the averaged gradient and Hessian, and compute + the Newton-Raphson update for the global parameter vector $\theta$, + based on the theoretical formula shown above. The averaging of + gradient and Hessian is based on the + [`WeightedAggregationHelper`](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_common/aggregators/weighted_aggregation_helper.py#L20), + which weighs the contribution from each client based on the number + of local training samples. The aggregated Newton-Raphson update is + returned as an `FLModel`. +- After getting the aggregated Newton-Raphson update, an + [`update_model()`](job/newton_raphson/app/custom/newton_raphson_workflow.py#L172) + method is implemented to actually apply the Newton-Raphson update to + the global model. +- The last step is to save the updated global model, again through + the `NewtonRaphsonModelPersistor` using `save_model()`. + + +On the client side, the local training logic is implemented +[here](job/newton_raphson/app/custom/newton_raphson_train.py). The +implementation is based on the [`Client +API`](https://nvflare.readthedocs.io/en/main/programming_guide/execution_api_type.html#client-api). This +allows user to add minimum `nvflare`-specific codes to turn a typical +centralized training script to a federated client side local training +script. +- During local training, each client receives a copy of the global + model, sent by the server, using `flare.receive()` API. The received + global model is an instance of `FLModel`. +- A local validation is first performed, where validation metrics + (accuracy and precision) are streamed to server using the + [`SummaryWriter`](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.client.tracking.html#nvflare.client.tracking.SummaryWriter). The + streamed metrics can be loaded and visualized using tensorboard. +- Then each client computes it's gradient and Hessian based on local + training data, using their respective theoretical formula described + above. This is implemented in the + [`train_newton_raphson()`](job/newton_raphson/app/custom/newton_raphson_train.py#L82) + method. Each client then sends the computed results (always in + `FLModel` format) to server for aggregation, using `flare.send()` + API. + +Each client site corresponds to a site listed in the data table above. + +A [centralized training script](./train_centralized.py) is also +provided, which allows for comparing the federated Newton-Raphson +optimization versus the centralized version. In the centralized +version, training data samples from all 4 sites were concatenated into +a single matrix, used to optimize the model parameters. The +optimized model was then tested separately on testing data samples of +the 4 sites, using accuracy and precision as metrics. + +Comparing the federated [client-side training +code](job/newton_raphson/app/custom/newton_raphson_train.py) with the +centralized [training code](./train_centralized.py), we can see that +the training logic remains similar: load data, perform training +(Newton-Raphson updates), and valid trained model. The only added +differences in the federated code are related to interaction with the +FL system, such as receiving and send `FLModel`. + +## Set Up Environment & Install Dependencies + +Follow instructions +[here](https://github.com/NVIDIA/NVFlare/tree/main/examples#set-up-a-virtual-environment) +to set up a virtual environment for `nvflare` examples and install +dependencies for this example. + +## Download and prepare data + +Execute the following script +``` +bash ./prepare_heart_disease_data.sh +``` +This will download the heart disease dataset under +`/tmp/flare/dataset/heart_disease_data/` + +## Centralized Logistic Regression + +Launch the following script: +``` +python ./train_centralized.py --solver custom +``` + +Two implementations of logistic regression are provided in the +centralized training script, which can be specified by the `--solver` +argument: +- One is using `sklearn.LogisticRegression` with `newton-cholesky` + solver +- The other one is manually implemented using the theoretical update + formulas described above. + +Both implementations were tested to converge in 4 iterations and to +give the same result. + +Example output: +``` +using solver: custom +loading training data. +training data X loaded. shape: (486, 13) +training data y loaded. shape: (486, 1) + +site - 1 +validation set n_samples: 104 +accuracy: 0.75 +precision: 0.7115384615384616 + +site - 2 +validation set n_samples: 89 +accuracy: 0.7528089887640449 +precision: 0.6122448979591837 + +site - 3 +validation set n_samples: 16 +accuracy: 0.75 +precision: 1.0 + +site - 4 +validation set n_samples: 45 +accuracy: 0.6 +precision: 0.9047619047619048 +``` + +## Federated Logistic Regression + +Execute the following command to launch federated logistic +regression. This will run in `nvflare`'s simulator mode. +``` +nvflare simulator -w ./workspace -n 4 -t 4 job/newton_raphson/ +``` + +Accuracy and precision for each site can be viewed in Tensorboard: +``` +tensorboard --logdir=./workspace/server/simulate_job/tb_events +``` +As can be seen from the figure below, per-site evaluation metrics in +federated logistic regression are on-par with the centralized version. + +Tensorboard metrics server diff --git a/examples/advanced/lr-newton-raphson/figs/tb-metrics.png b/examples/advanced/lr-newton-raphson/figs/tb-metrics.png new file mode 100644 index 0000000000..148bacea0f Binary files /dev/null and b/examples/advanced/lr-newton-raphson/figs/tb-metrics.png differ diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_client.json b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_client.json new file mode 100755 index 0000000000..75413266b0 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_client.json @@ -0,0 +1,75 @@ +{ + "format_version": 2, + "app_script": "newton_raphson_train.py", + "app_config": "--data_root /tmp/flare/dataset/heart_disease_data", + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "path": "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor", + "args": { + "launcher_id": "launcher", + "pipe_id": "pipe", + "heartbeat_timeout": 60, + "params_exchange_format": "raw", + "params_transfer_type": "FULL", + "train_with_evaluation": false + } + } + } + ], + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "launcher", + "path": "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher", + "args": { + "script": "python3 custom/{app_script} {app_config}", + "launch_once": true + } + }, + { + "id": "pipe", + "path": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", + "args": { + "mode": "PASSIVE", + "site_name": "{SITE_NAME}", + "token": "{JOB_ID}", + "root_url": "{ROOT_URL}", + "secure_mode": "{SECURE_MODE}", + "workspace_dir": "{WORKSPACE}" + } + }, + { + "id": "metrics_pipe", + "path": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", + "args": { + "mode": "PASSIVE", + "site_name": "{SITE_NAME}", + "token": "{JOB_ID}", + "root_url": "{ROOT_URL}", + "secure_mode": "{SECURE_MODE}", + "workspace_dir": "{WORKSPACE}" + } + }, + { + "id": "metric_relay", + "path": "nvflare.app_common.widgets.metric_relay.MetricRelay", + "args": { + "pipe_id": "metrics_pipe", + "event_type": "fed.analytix_log_stats", + "read_interval": 0.1 + } + }, + { + "id": "client_api_config_preparer", + "path": "nvflare.app_common.widgets.external_configurator.ExternalConfigurator", + "args": { + "component_ids": ["metric_relay"] + } + } + ] +} diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_server.json b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_server.json new file mode 100755 index 0000000000..d924e556d2 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/config/config_fed_server.json @@ -0,0 +1,34 @@ +{ + "format_version": 2, + "server": { + "heart_beat_timeout": 600 + }, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "newton_raphson_persistor", + "path": "newton_raphson_persistor.NewtonRaphsonModelPersistor", + "args": { + "n_features": 13 + } + }, + { + "id": "tb_analytics_receiver", + "path": "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver", + "args.events": ["fed.analytix_log_stats"] + } + ], + "workflows": [ + { + "id": "fedavg_newton_raphson", + "path": "newton_raphson_workflow.FedAvgNewtonRaphson", + "args": { + "min_clients": 4, + "num_rounds": 5, + "damping_factor": 0.8, + "persistor_id": "newton_raphson_persistor" + } + } + ] +} diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_persistor.py b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_persistor.py new file mode 100644 index 0000000000..5b324dd50c --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_persistor.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np + +from nvflare.app_common.np.np_model_persistor import NPModelPersistor + + +class NewtonRaphsonModelPersistor(NPModelPersistor): + """ + This class defines the persistor for Newton Raphson model. + + A persistor controls the logic behind initializing, loading + and saving of the model / parameters for each round of a + federated learning process. + + In the 2nd order Newton Raphson case, a model is just a + 1-D numpy vector containing the parameters for logistic + regression. The length of the parameter vector is defined + by the number of features in the dataset. + + """ + + def __init__(self, model_dir="models", model_name="weights.npy", n_features=13): + """ + Init function for NewtonRaphsonModelPersistor. + + Args: + model_dir: sub-folder name to save and load the global model + between rounds. + model_name: name to save and load the global model. + n_features: number of features for the logistic regression. + For the UCI ML heart Disease dataset, this is 13. + + """ + + super().__init__() + + self.model_dir = model_dir + self.model_name = model_name + self.n_features = n_features + + # A default model is loaded when no local model is available. + # This happen when training starts. + # + # A `model` for a binary logistic regression is just a matrix, + # with shape (n_features + 1, 1). + # For the UCI ML Heart Disease dataset, the n_features = 13. + # + # A default matrix with value 0s is created. + # + self.default_data = np.zeros((self.n_features + 1, 1), dtype=np.float32) diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_train.py b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_train.py new file mode 100644 index 0000000000..419b9ed70b --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_train.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +import numpy as np +from sklearn.metrics import accuracy_score, precision_score + +import nvflare.client as flare +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.np.constants import NPConstants +from nvflare.client.tracking import SummaryWriter + + +def parse_arguments(): + """ + Parse command line args for client side training. + """ + parser = argparse.ArgumentParser(description="Federated Second-Order Newton Raphson") + + parser.add_argument("--data_root", type=str, help="Path to load client side data.") + + return parser.parse_args() + + +def load_data(data_root, site_name): + """ + Load the data for each client. + + Args: + data_root: root directory storing client site data. + site_name: client site name + Returns: + A dict with client site training and validation data. + """ + print("loading data for client {} from: {}".format(site_name, data_root)) + train_x_path = os.path.join(data_root, "{}.train.x.npy".format(site_name)) + train_y_path = os.path.join(data_root, "{}.train.y.npy".format(site_name)) + test_x_path = os.path.join(data_root, "{}.test.x.npy".format(site_name)) + test_y_path = os.path.join(data_root, "{}.test.y.npy".format(site_name)) + + train_X = np.load(train_x_path) + train_y = np.load(train_y_path) + valid_X = np.load(test_x_path) + valid_y = np.load(test_y_path) + + return {"train_X": train_X, "train_y": train_y, "valid_X": valid_X, "valid_y": valid_y} + + +def sigmoid(inp): + return 1.0 / (1.0 + np.exp(-inp)) + + +def train_newton_raphson(data, theta): + """ + Compute gradient and hessian on local data + based on paramters received from server. + + """ + train_X = data["train_X"] + train_y = data["train_y"] + + # Add intercept, pre-pend 1s to as first + # column of train_X + train_X = np.concatenate((np.ones((train_X.shape[0], 1)), train_X), axis=1) + + # Compute probabilities from current weights + proba = sigmoid(np.dot(train_X, theta)) + + # The gradient is X^T . (y - proba) + gradient = np.dot(train_X.T, (train_y - proba)) + + # The hessian is X^T . D . X, where D is the + # diagnoal matrix with values proba * (1 - proba) + D = np.diag((proba * (1 - proba))[:, 0]) + hessian = train_X.T.dot(D).dot(train_X) + + return {"gradient": gradient, "hessian": hessian} + + +def validate(data, theta): + """ + Performs local validation. + Computes accuracy and precision scores. + + """ + valid_X = data["valid_X"] + valid_y = data["valid_y"] + + # Add intercept, pre-pend 1s to as first + # column of valid_X + valid_X = np.concatenate((np.ones((valid_X.shape[0], 1)), valid_X), axis=1) + + # Compute probabilities from current weights + proba = sigmoid(np.dot(valid_X, theta)) + + return {"accuracy": accuracy_score(valid_y, proba.round()), "precision": precision_score(valid_y, proba.round())} + + +def main(): + """ + This is a typical ML training loop, + augmented with Flare Client API to + perform local training on each client + side and send result to server. + + """ + args = parse_arguments() + + flare.init() + + site_name = flare.get_site_name() + print("training on client site: {}".format(site_name)) + + # Load client site data. + data = load_data(args.data_root, site_name) + + # Get metric summary writer + writer = SummaryWriter() + + while flare.is_running(): + + # Receive global model (FLModel) from server. + global_model = flare.receive() + + curr_round = global_model.current_round + print("current_round={}".format(curr_round)) + + print( + ("[ROUND {}] - client site: {}, received " "global model: {}").format(curr_round, site_name, global_model) + ) + + # Get the weights, aka parameter theta for + # logistic regression. + global_weights = global_model.params[NPConstants.NUMPY_KEY] + print("[ROUND {}] - global model weights: {}".format(curr_round, global_weights)) + + # Local validation before training + print(("[ROUND {}] - start validation of global " "model on client: {}").format(curr_round, site_name)) + validation_scores = validate(data, global_weights) + print( + ("[ROUND {}] - validation metric scores on " "client: {} = {}").format( + curr_round, site_name, validation_scores + ) + ) + + # Write validation metric summary + writer.add_scalar("{}/accuracy".format(site_name), validation_scores["accuracy"], curr_round) + + writer.add_scalar("{}/precision".format(site_name), validation_scores["precision"], curr_round) + + # Local training + print(("[ROUND {}] - start local training on client " "site: {}").format(curr_round, site_name)) + result_dict = train_newton_raphson(data, theta=global_weights) + + # Send result to server for aggregation. + result_model = FLModel(params=result_dict, params_type=ParamsType.FULL) + result_model.meta[FLMetaKey.NUM_STEPS_CURRENT_ROUND] = data["train_X"].shape[0] + + print( + ( + "[ROUND {}] - local newton raphson training from " "client: {} complete, sending results to server: {}" + ).format(curr_round, site_name, result_model) + ) + + flare.send(result_model) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_workflow.py b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_workflow.py new file mode 100644 index 0000000000..56c4d87a46 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/app/custom/newton_raphson_workflow.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import numpy as np + +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.np.constants import NPConstants +from nvflare.app_common.workflows.base_fedavg import BaseFedAvg + + +class FedAvgNewtonRaphson(BaseFedAvg): + def __init__(self, damping_factor, epsilon=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + """ + Init function for FedAvgNewtonRaphson. + + Args: + damping_factor: damping factor for Newton Raphson updates. + epsilon: a regularization factor to avoid empty hessian for + matrix inversion + """ + self.damping_factor = damping_factor + self.epsilon = epsilon + self.aggregator = WeightedAggregationHelper() + + def run(self) -> None: + """ + The run function executes the logic of federated + second order Newton Raphson optimization. + + """ + self.info("starting Federated Averaging Netwon Raphson ...") + + # First load the model and set up some training params. + # A `persisitor` (NewtonRaphsonModelPersistor) will load + # the model in `ModelLearnable` format, then will be + # converted `FLModel` by `ModelController`. + # + model = self.load_model() + + model.start_round = self.start_round + model.total_rounds = self.num_rounds + + self.info("Server side model loader: {}".format(model)) + + for self.current_round in range(self.start_round, self.start_round + self.num_rounds): + self.info(f"Round {self.current_round} started.") + + # Get the list of clients. + clients = self.sample_clients(self.min_clients) + + model.current_round = self.current_round + + # Send training task and current global model to clients. + # + # A `task` isntance will be created, and sent + # to clients, the model is first converted to a shareable + # and is attached to the task. + # + # After the task is finished, the result (shareable) recieved + # from the task is converted to FLModel, and is returned to the + # server. The `results` below is a list with result (FLModel) + # from all clients. + # + # The full logic of `task` is implemented in: + # https://github.com/NVIDIA/NVFlare/blob/d6827bca96d332adb3402ceceb4b67e876146067/nvflare/app_common/workflows/model_controller.py#L178 + # + self.info("sending server side global model to clients") + results = self.send_model_and_wait(targets=clients, data=model) + + # Aggregate results receieved from clients. + aggregate_results = self.aggregate(results, aggregate_fn=self.newton_raphson_aggregator_fn) + + # Update global model based on the following formula: + # weights = weights + updates, where + # updates = -damping_factor * Hessian^{-1} . Gradient + self.update_model(model, aggregate_results) + + # Save global model. + self.save_model(model) + + self.info("Finished FedAvg.") + + def newton_raphson_aggregator_fn(self, results: List[FLModel]): + """ + Custom aggregator function for second order Newton Raphson + optimization. + + This uses the default thread-safe WeightedAggregationHelper, + which implement a weighted average of all values received from + a `result` dictionary. + + Args: + results: a list of `FLModel`s. Each `FLModel` is received + from a client. The field `params` is a dictionary that + contains values to be aggregated: the gradient and hessian. + """ + self.info("receieved results from clients: {}".format(results)) + + # On client side the `NUM_STEPS_CURRENT_ROUND` key + # is used to track the number of samples for each client. + for curr_result in results: + self.aggregator.add( + data=curr_result.params, + weight=curr_result.meta.get(FLMetaKey.NUM_STEPS_CURRENT_ROUND, 1.0), + contributor_name=curr_result.meta.get("client_name", AppConstants.CLIENT_UNKNOWN), + contribution_round=curr_result.current_round, + ) + + aggregated_dict = self.aggregator.get_result() + self.info("aggregated result: {}".format(aggregated_dict)) + + # Compute global model update: + # update = - damping_factor * Hessian^{-1} . Gradient + # A regularization is added to avoid empty hessian. + # + reg = self.epsilon * np.eye(aggregated_dict["hessian"].shape[0]) + newton_raphson_updates = self.damping_factor * np.linalg.solve( + aggregated_dict["hessian"] + reg, aggregated_dict["gradient"] + ) + self.info("newton raphson updates: {}".format(newton_raphson_updates)) + + # Convert the aggregated result to `FLModel`, this `FLModel` + # will then be used by `update_model` method from the base class, + # to update the global model weights. + # + aggr_result = FLModel( + params={"newton_raphson_updates": newton_raphson_updates}, + params_type=results[0].params_type, + meta={ + "nr_aggregated": len(results), + AppConstants.CURRENT_ROUND: results[0].current_round, + AppConstants.NUM_ROUNDS: self.num_rounds, + }, + ) + return aggr_result + + def update_model(self, model, model_update, replace_meta=True) -> FLModel: + """ + Update logistic regression parameters based on + aggregated gradient and hessian. + + """ + if replace_meta: + model.meta = model_update.meta + else: + model.meta.update(model_update.meta) + + model.metrics = model_update.metrics + model.params[NPConstants.NUMPY_KEY] += model_update.params["newton_raphson_updates"] diff --git a/examples/advanced/lr-newton-raphson/job/newton_raphson/meta.json b/examples/advanced/lr-newton-raphson/job/newton_raphson/meta.json new file mode 100644 index 0000000000..c157e9f65a --- /dev/null +++ b/examples/advanced/lr-newton-raphson/job/newton_raphson/meta.json @@ -0,0 +1,10 @@ +{ + "name": "newton_raphson", + "resource_spec": {}, + "min_clients" : 4, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/examples/advanced/lr-newton-raphson/prepare_heart_disease_data.sh b/examples/advanced/lr-newton-raphson/prepare_heart_disease_data.sh new file mode 100755 index 0000000000..c297f15e71 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/prepare_heart_disease_data.sh @@ -0,0 +1,29 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +DATA_DIR=/tmp/flare/dataset/heart_disease_data + +# Install dependencies +#pip install wget +FLAMBY_INSTALL_DIR=$(python3 -c "import sysconfig; print(sysconfig.get_path('purelib'))") +# git clone https://github.com/owkin/FLamby.git && cd FLamby && pip install -e . + +# Download data using FLamby +mkdir -p ${DATA_DIR} +python3 ${FLAMBY_INSTALL_DIR}/flamby/datasets/fed_heart_disease/dataset_creation_scripts/download.py --output-folder ${DATA_DIR} + +# Convert data to numpy files +python3 ${SCRIPT_DIR}/utils/convert_data_to_np.py ${DATA_DIR} diff --git a/examples/advanced/lr-newton-raphson/requirements.txt b/examples/advanced/lr-newton-raphson/requirements.txt new file mode 100644 index 0000000000..513c8f8be0 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/requirements.txt @@ -0,0 +1,2 @@ +flamby @ git+https://github.com/owkin/FLamby.git@main +wget==3.2 diff --git a/examples/advanced/lr-newton-raphson/train_centralized.py b/examples/advanced/lr-newton-raphson/train_centralized.py new file mode 100755 index 0000000000..c64ee6cb6e --- /dev/null +++ b/examples/advanced/lr-newton-raphson/train_centralized.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +import numpy as np +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, precision_score + +DATA_ROOT = "/tmp/flare/dataset/heart_disease_data/" + +MAX_ITERS = 4 +EPSILON = 1.0 + + +def sigmoid(inp): + return 1.0 / (1.0 + np.exp(-inp)) + + +def lr_solver(X, y): + """ + Custom logistic regression solver using Newton Raphson + method. + + """ + n_features = X.shape[1] + theta = np.zeros((n_features + 1, 1)) + X = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1) + + for iter in range(MAX_ITERS): + proba = sigmoid(np.dot(X, theta)) + gradient = np.dot(X.T, (y - proba)) + D = np.diag((proba * (1 - proba))[:, 0]) + hessian = X.T.dot(D).dot(X) + + reg = EPSILON * np.eye(hessian.shape[0]) + updates = np.linalg.solve(hessian + reg, gradient) + + theta += updates + + return theta + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--solver", + type=str, + default="custom", + help=("which solver to use: custom (default) or sklearn " "LogisticRegression. The results are the same. "), + ) + args = parser.parse_args() + + print("using solver:", args.solver) + + print("loading training data.") + train_X = np.concatenate( + ( + np.load(os.path.join(DATA_ROOT, "site-1.train.x.npy")), + np.load(os.path.join(DATA_ROOT, "site-2.train.x.npy")), + np.load(os.path.join(DATA_ROOT, "site-3.train.x.npy")), + np.load(os.path.join(DATA_ROOT, "site-4.train.x.npy")), + ) + ) + train_y = np.concatenate( + ( + np.load(os.path.join(DATA_ROOT, "site-1.train.y.npy")), + np.load(os.path.join(DATA_ROOT, "site-2.train.y.npy")), + np.load(os.path.join(DATA_ROOT, "site-3.train.y.npy")), + np.load(os.path.join(DATA_ROOT, "site-4.train.y.npy")), + ) + ) + +if args.solver == "sklearn": + train_y = train_y.reshape(-1) + +print("training data X loaded. shape:", train_X.shape) +print("training data y loaded. shape:", train_y.shape) + +if args.solver == "sklearn": + clf = LogisticRegression(random_state=0, solver="newton-cholesky", verbose=1).fit(train_X, train_y) + +else: + theta = lr_solver(train_X, train_y) + +for site in range(4): + + print("\nsite - {}".format(site + 1)) + test_X = np.load(os.path.join(DATA_ROOT, "site-{}.test.x.npy".format(site + 1))) + test_y = np.load(os.path.join(DATA_ROOT, "site-{}.test.y.npy".format(site + 1))) + test_y = test_y.reshape(-1) + + print("validation set n_samples: ", test_X.shape[0]) + + if args.solver == "sklearn": + proba = clf.predict_proba(test_X) + proba = proba[:, 1] + + else: + test_X = np.concatenate((np.ones((test_X.shape[0], 1)), test_X), axis=1) + proba = sigmoid(np.dot(test_X, theta)) + + print("accuracy:", accuracy_score(test_y, proba.round())) + print("precision:", precision_score(test_y, proba.round())) diff --git a/examples/advanced/lr-newton-raphson/utils/convert_data_to_np.py b/examples/advanced/lr-newton-raphson/utils/convert_data_to_np.py new file mode 100755 index 0000000000..a35ba16084 --- /dev/null +++ b/examples/advanced/lr-newton-raphson/utils/convert_data_to_np.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +import numpy as np +from flamby.datasets.fed_heart_disease import FedHeartDisease +from torch.utils.data import DataLoader as dl + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("save UCI Heart Disease as numpy arrays.") + parser.add_argument("save_dir", type=str, help="directory to save converted numpy arrays as .npy files.") + args = parser.parse_args() + + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir, exist_ok=True) + + for site in range(4): + + for flag in ("train", "test"): + + # To load data a pytorch dataset + data = FedHeartDisease(center=site, train=(flag == "train")) + + # Save training dataset + data_x = [] + data_y = [] + for x, y in dl(data, batch_size=1, shuffle=False, num_workers=0): + data_x.append(x.cpu().numpy().reshape(-1)) + data_y.append(y.cpu().numpy().reshape(-1)) + + data_x = np.array(data_x).reshape(-1, 13) + data_y = np.array(data_y).reshape(-1, 1) + + print("site {} - {} - variables shape: {}".format(site, flag, data_x.shape)) + print("site {} - {} - outcomes shape: {}".format(site, flag, data_y.shape)) + + save_x_path = "{}/site-{}.{}.x.npy".format(args.save_dir, site + 1, flag) + print("saving data: {}".format(save_x_path)) + np.save(save_x_path, data_x) + + save_y_path = "{}/site-{}.{}.y.npy".format(args.save_dir, site + 1, flag) + print("saving data: {}".format(save_y_path)) + np.save(save_y_path, data_y)