diff --git a/docs/real_world_fl.rst b/docs/real_world_fl.rst index 3ff6cf1f77..ed8190cdf1 100644 --- a/docs/real_world_fl.rst +++ b/docs/real_world_fl.rst @@ -28,4 +28,5 @@ to see the capabilities of the system and how it can be operated. real_world_fl/job real_world_fl/workspace real_world_fl/cloud_deployment + real_world_fl/notes_on_large_models user_guide/federated_authorization diff --git a/docs/real_world_fl/notes_on_large_models.rst b/docs/real_world_fl/notes_on_large_models.rst new file mode 100644 index 0000000000..fdb26a9fab --- /dev/null +++ b/docs/real_world_fl/notes_on_large_models.rst @@ -0,0 +1,90 @@ +.. _notes_on_large_models: + +Large Models +============ +As the federated learning tasks become more and more complex, their model sizes increase. Some model sizes may go beyond 2GB and even reach hundreds of GB. NVIDIA FLARE supports +large models as long as the system memory of servers and clients is capable of handling it. However, it requires special consideration on NVIDIA FLARE configuration and the system because +the network bandwidth and thus the time to transmit such large amount of data during NVIDIA FLARE job runtime varies significantly. Here we describe 128GB model training jobs to highlight +the configuration and system that users should consider for sucessful large model federated learning jobs. + +System Deployment +***************** +Our successful experiments of 128GB model training were running on one NVIDIA FLARE server and two clients. The server was deployed in Azure west-us region. One of those two clients +was deployed in AWS west-us-2 region and the other was in AWS ap-south-1 region. The system was deployed in such cross-region and cross-cloud-service-provider manner so that we can test +NVIDIA FLARE system with various conditions on the network bandwidth. +The Azure VM size of the NVIDIA FLARE server was M32-8ms, which has 875GB memory. The AWS EC2 instance type of NVIDIA FLARE clients was r5a.16xlarge with 512GB memory. We also enabled +128GB swap space on all machines. + +Job of 128GB Models +******************* +We slightly modified the hello-numpy example to generate a model, which was a dictionary of 64 keys. Each key held a 2GB numpy array. The local training task was to add a small number to +those numpy arrays. The aggregator on the server side was not changed. This job required at least two clients and ran 3 rounds to finish. + +Please note if your model contains leaf nodes that are larger than 4GB, the type of those nodes must be bytes. In that case, the outgoing model will need a conversion similar to the following: + +.. code:: python + + for k in np_data: + self.log_info(fl_ctx, f"converting {k=}") + tmp_file = io.BytesIO() + np.save(tmp_file, np_data[k]) + np_data[k] = tmp_file.getvalue() + tmp_file.close() + self.log_info(fl_ctx, f"done converting {k=}") + outgoing_dxo = DXO(data_kind=incoming_dxo.data_kind, data=np_data, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: 1}) + +Additionally, the receiving side needs to convert the bytes back to numpy array with codes similar to the following: + +.. code:: python + + for k in np_data: + self.log_info(fl_ctx, f"converting and adding delta for {k=}") + np_data[k] = np.load(io.BytesIO(np_data[k])) + + +Configuration +******************* +We measured the bandwidth between the server and west-us-2 client. It took around 2300 seconds to transfer the model from the client to the server and around 2000 seconds from the server to the client. +On the ap-south-1 client, it took about 11000 seconds from the client to the server and 11500 seconds from the server to the client. We updated the following values to accommodate such differences. + + - streaming_read_timeout to 3000 + - streaming_ack_wait to 6000 + - communication_timeout to 6000 + + +The `streaming_read_timeout` is used to check when a chunck of data is received but is not read out by the upper layer. The `streaming_ack_wait` is how long the sender should wait for acknowledgement returned by the receiver for one chunck. + + +The `communication_timeout` is used on three consecutive stages for a single request and response. When sending a large request (submit_update), the sender starts a timer with timeout = `communication_timeout`. +When this timer expires, the sender checks if any progress is made during this period. If yes, the sender resets the timer with the same timeout value and waits again. If not, this request and response returns with timeout. +After sending completes, the sender cancels the previous timer and starts a `remote processing` timer with timeout = `communication_timeout`. This is to wait for the first returned byte from the receiver. On +large models, the server requires much longer time to prepare the task when the clients send `get_task` requests. After receiving the first returned byte, the sender cancel the `remote processing` timer and starts +a new timer. It checks the receiving progress just like sending. + + +Since the experiment was based on hello-numpy, one of the arguments, `train_timeout` in the ScatterAndGather class had to be updated. This timeout is used to check the scheduling of training tasks. We +changed this argument to 60000 for this experiment. + +Memory Usage +******************* +During the experiment, the server could use more than 512GB, ie 128GB * 2 clients * 2 (model and runtime space). The following figure shows the CPU and memory usage of the server. + +.. image:: ../resources/128GB_server.png + :height: 350px + +Although most of the time, the server was using less than 512GB, there were a few peaks that reached 700GB or more. + +The followings are clients, west-us-2 and ap-south-1. + +.. image:: ../resources/128GB_site1.png + :height: 350px + + +.. image:: ../resources/128GB_site2.png + :height: 350px + + +The west-us-2 client, with its fast bandwidth with the server, received and sent the models in about 100 minutes and entered nearly idle state with little cpu and memory usage. Both +clients used about 256GB, ie 128GB * 2 (model and runtime space), but at the end of receiving large models and at the beginning of sending large models, these two clients required more than +378GB, ie 128GB * 3. + diff --git a/docs/resources/128GB_server.png b/docs/resources/128GB_server.png new file mode 100644 index 0000000000..d8cbbf26b4 Binary files /dev/null and b/docs/resources/128GB_server.png differ diff --git a/docs/resources/128GB_site1.png b/docs/resources/128GB_site1.png new file mode 100644 index 0000000000..02e99f3bc5 Binary files /dev/null and b/docs/resources/128GB_site1.png differ diff --git a/docs/resources/128GB_site2.png b/docs/resources/128GB_site2.png new file mode 100644 index 0000000000..494fb3068a Binary files /dev/null and b/docs/resources/128GB_site2.png differ diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/executor.py b/examples/hello-world/step-by-step/cifar10/code/fl/executor.py new file mode 100644 index 0000000000..27e648997d --- /dev/null +++ b/examples/hello-world/step-by-step/cifar10/code/fl/executor.py @@ -0,0 +1,291 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +from net import Net +from torchvision import transforms + +from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReservedKey, ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager + + +class CIFAR10Executor(Executor): + def __init__( + self, + epochs: int = 2, + lr: float = 1e-2, + momentum: float = 0.9, + batch_size: int = 4, + num_workers: int = 1, + dataset_path: str = "/tmp/nvflare/data/cifar10", + model_path: str = "/tmp/nvflare/data/cifar10/cifar_net.pth", + device="cuda:0", + pre_train_task_name=AppConstants.TASK_GET_WEIGHTS, + train_task_name=AppConstants.TASK_TRAIN, + submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL, + validate_task_name=AppConstants.TASK_VALIDATION, + exclude_vars=None, + ): + """Cifar10 Executor handles train, validate, and submit_model tasks. During train_task, it trains a + simple network on CIFAR10 dataset. For validate task, it evaluates the input model on the test set. + For submit_model task, it sends the locally trained model (if present) to the server. + + Args: + epochs (int, optional): Epochs. Defaults to 2 + lr (float, optional): Learning rate. Defaults to 0.01 + momentum (float, optional): Momentum. Defaults to 0.9 + batch_size: batch size for training and validation. + num_workers: number of workers for data loaders. + dataset_path: path to dataset + model_path: path to save model + device: (optional) We change to use GPU to speed things up. if you want to use CPU, change DEVICE="cpu" + pre_train_task_name: Task name for pre train task, i.e., sending initial model weights. + train_task_name (str, optional): Task name for train task. Defaults to "train". + submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model". + validate_task_name (str, optional): Task name for validate task. Defaults to "validate". + exclude_vars (list): List of variables to exclude during model loading. + """ + super().__init__() + self.epochs = epochs + self.lr = lr + self.momentum = momentum + self.batch_size = batch_size + self.num_workers = num_workers + self.dataset_path = dataset_path + self.model_path = model_path + + self.pre_train_task_name = pre_train_task_name + self.train_task_name = train_task_name + self.submit_model_task_name = submit_model_task_name + self.validate_task_name = validate_task_name + self.device = device + self.exclude_vars = exclude_vars + + self.train_dataset = None + self.valid_dataset = None + self.train_loader = None + self.valid_loader = None + self.net = None + self.optimizer = None + self.criterion = None + self.persistence_manager = None + self.best_acc = 0.0 + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.initialize() + + def initialize(self): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + self.trainset = torchvision.datasets.CIFAR10( + root=self.dataset_path, train=True, download=True, transform=transform + ) + self.trainloader = torch.utils.data.DataLoader( + self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers + ) + self._n_iterations = len(self.trainloader) + + self.testset = torchvision.datasets.CIFAR10( + root=self.dataset_path, train=False, download=True, transform=transform + ) + self.testloader = torch.utils.data.DataLoader( + self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers + ) + + self.net = Net() + + self.criterion = nn.CrossEntropyLoss() + self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, momentum=self.momentum) + + self._default_train_conf = {"train": {"model": type(self.net).__name__}} + self.persistence_manager = PTModelPersistenceFormatManager( + data=self.net.state_dict(), default_train_conf=self._default_train_conf + ) + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + try: + if task_name == self.pre_train_task_name: + # Get the new state dict and send as weights + return self._get_model_weights() + if task_name == self.train_task_name: + # Get model weights + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Unable to extract dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure data kind is weights. + if not dxo.data_kind == DataKind.WEIGHTS: + self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Convert weights to tensor. Run training + torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()} + self._local_train(fl_ctx, torch_weights) + + # Check the abort_signal after training. + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # Save the local model after training. + self._save_local_model(fl_ctx) + + # Get the new state dict and send as weights + return self._get_model_weights() + if task_name == self.validate_task_name: + model_owner = "?" + try: + try: + dxo = from_shareable(shareable) + except: + self.log_error(fl_ctx, "Error in extracting dxo from shareable.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Ensure data_kind is weights. + if not dxo.data_kind == DataKind.WEIGHTS: + self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.") + return make_reply(ReturnCode.BAD_TASK_DATA) + + # Extract weights and ensure they are tensor. + model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?") + weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()} + + # Get validation accuracy + val_accuracy = self._local_validate(fl_ctx, weights) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + self.log_info( + fl_ctx, + f"Accuracy when validating {model_owner}'s model on" + f" {fl_ctx.get_identity_name()}" + f"s data: {val_accuracy}", + ) + + dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy}) + return dxo.to_shareable() + except: + self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + elif task_name == self.submit_model_task_name: + # Load local model + ml = self._load_local_model(fl_ctx) + + # Get the model parameters and create dxo from it + dxo = model_learnable_to_dxo(ml) + return dxo.to_shareable() + else: + return make_reply(ReturnCode.TASK_UNKNOWN) + except Exception as e: + self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + def _get_model_weights(self) -> Shareable: + # Get the new state dict and send as weights + weights = {k: v.cpu().numpy() for k, v in self.net.state_dict().items()} + + outgoing_dxo = DXO( + data_kind=DataKind.WEIGHTS, data=weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations} + ) + return outgoing_dxo.to_shareable() + + def _local_train(self, fl_ctx, input_weights): + self.net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + self.net.to(self.device) + + for epoch in range(self.epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(self.trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(self.device), data[1].to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.criterion(outputs, labels) + loss.backward() + self.optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + self.log_info(fl_ctx, f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + running_loss = 0.0 + + self.log_info(fl_ctx, "Finished Training") + + def _local_validate(self, fl_ctx, input_weights): + self.net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + self.net.to(self.device) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in self.testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(self.device), data[1].to(self.device) + # calculate outputs by running images through the network + outputs = self.net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + val_accuracy = 100 * correct // total + self.log_info(fl_ctx, f"Accuracy of the network on the 10000 test images: {val_accuracy} %") + return val_accuracy + + def _save_local_model(self, fl_ctx: FLContext): + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) + models_dir = os.path.join(run_dir, "models") + if not os.path.exists(models_dir): + os.makedirs(models_dir) + + ml = make_model_learnable(self.net.state_dict(), {}) + self.persistence_manager.update(ml) + torch.save(self.persistence_manager.to_persistence_dict(), self.model_path) + + def _load_local_model(self, fl_ctx: FLContext): + run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM)) + models_dir = os.path.join(run_dir, "models") + if not os.path.exists(models_dir): + return None + + self.persistence_manager = PTModelPersistenceFormatManager( + data=torch.load(self.model_path), default_train_conf=self._default_train_conf + ) + ml = self.persistence_manager.to_model_learnable(exclude_vars=self.exclude_vars) + return ml diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/model_learner.py b/examples/hello-world/step-by-step/cifar10/code/fl/model_learner.py new file mode 100644 index 0000000000..c500265367 --- /dev/null +++ b/examples/hello-world/step-by-step/cifar10/code/fl/model_learner.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Union + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +from net import Net +from torchvision import transforms + +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.abstract.model_learner import ModelLearner +from nvflare.app_common.app_constant import ModelName + + +class CIFAR10ModelLearner(ModelLearner): + def __init__( + self, + epochs: int = 2, + lr: float = 1e-2, + momentum: float = 0.9, + batch_size: int = 4, + num_workers: int = 1, + dataset_path: str = "/tmp/nvflare/data/cifar10", + model_path: str = "/tmp/nvflare/data/cifar10/cifar_net.pth", + device: str = "cuda:0", + ): + """CIFAR-10 Trainer. + + Args: + epochs: the number of training epochs for a round. Defaults to 1. + lr: local learning rate. Float number. Defaults to 1e-2. + momentum (float, optional): Momentum. Defaults to 0.9 + batch_size: batch size for training and validation. + num_workers: number of workers for data loaders. + dataset_path: path to dataset + model_path: path to save model + device: (optional) We change to use GPU to speed things up. if you want to use CPU, change DEVICE="cpu" + + Returns: + an FLModel with the updated local model differences after running `train()`, the metrics after `validate()`, + or the best local model depending on the specified task. + """ + super().__init__() + self.epochs = epochs + self.lr = lr + self.momentum = momentum + self.batch_size = batch_size + self.num_workers = num_workers + self.dataset_path = dataset_path + self.model_path = model_path + + self.train_dataset = None + self.train_loader = None + self.valid_dataset = None + self.valid_loader = None + + self.net = None + self.optimizer = None + self.criterion = None + self.device = device + self.best_acc = 0.0 + + def initialize(self): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + self.trainset = torchvision.datasets.CIFAR10( + root=self.dataset_path, train=True, download=True, transform=transform + ) + self.trainloader = torch.utils.data.DataLoader( + self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers + ) + + self.testset = torchvision.datasets.CIFAR10( + root=self.dataset_path, train=False, download=True, transform=transform + ) + self.testloader = torch.utils.data.DataLoader( + self.testset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers + ) + + self.net = Net() + self.criterion = nn.CrossEntropyLoss() + self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr, momentum=self.momentum) + + def get_model(self, model_name: str) -> Union[str, FLModel]: + # Retrieve the best local model saved during training. + if model_name == ModelName.BEST_MODEL: + try: + model_data = torch.load(self.model_path, map_location="cpu") + np_model_data = {k: v.cpu().numpy() for k, v in model_data.items()} + + return FLModel(params_type=ParamsType.FULL, params=np_model_data) + except Exception as e: + raise ValueError("Unable to load best model") from e + else: + raise ValueError(f"Unknown model_type: {model_name}") # Raised errors are caught in LearnerExecutor class. + + def train(self, model: FLModel) -> Union[str, FLModel]: + self.info(f"Current/Total Round: {self.current_round + 1}/{self.total_rounds}") + self.info(f"Client identity: {self.site_name}") + + pt_input_params = {k: torch.as_tensor(v) for k, v in model.params.items()} + self._local_train(pt_input_params) + + pt_output_params = {k: torch.as_tensor(v) for k, v in self.net.cpu().state_dict().items()} + accuracy = self._local_validate(pt_output_params) + + if accuracy > self.best_acc: + self.best_acc = accuracy + torch.save(self.net.state_dict(), self.model_path) + + np_output_params = {k: v.cpu().numpy() for k, v in self.net.cpu().state_dict().items()} + return FLModel( + params=np_output_params, + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": 2 * len(self.trainloader)}, + ) + + def _local_train(self, input_weights): + self.net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + self.net.to(self.device) + + for epoch in range(self.epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(self.trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(self.device), data[1].to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs = self.net(inputs) + loss = self.criterion(outputs, labels) + loss.backward() + self.optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + self.info(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + running_loss = 0.0 + + self.info("Finished Training") + + def validate(self, model: FLModel) -> Union[str, FLModel]: + pt_params = {k: torch.as_tensor(v) for k, v in model.params.items()} + val_accuracy = self._local_validate(pt_params) + + return FLModel(metrics={"val_accuracy": val_accuracy}) + + def _local_validate(self, input_weights): + self.net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + self.net.to(self.device) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in self.testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(self.device), data[1].to(self.device) + # calculate outputs by running images through the network + outputs = self.net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + val_accuracy = 100 * correct // total + self.info(f"Accuracy of the network on the 10000 test images: {val_accuracy} %") + return val_accuracy diff --git a/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb b/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb new file mode 100644 index 0000000000..70764edb41 --- /dev/null +++ b/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "514c47e2-420d-4af4-9bf0-cac337c51c39", + "metadata": {}, + "source": [ + "# FedAvg with SAG workflow using Executor\n", + "\n", + "In this example, we will demonstrate the FegAvg SAG workflow using the CIFAR10 dataset using an Executor. \n", + "\n", + "While the previous example [FedAvg with SAG workflow](../sag/sag.ipynb#title) utilized the Client API, here we will demonstrate how to convert the original training code into a Executor trainer, showcase its capabilities, and recommend the best use cases.\n", + "\n", + "For an overview on Federated Averaging and SAG, see the section from the previous example: [Understanding FedAvg and SAG](../sag/sag.ipynb#sag)\n", + "\n", + "## Executor\n", + "\n", + "An `Executor` in FLARE is an FLComponent for clients used for executing tasks, wherein the `execute` method receives and returns a `Shareable` object given a task name.\n", + "\n", + "Key Concepts:\n", + "- Executor is a client-side FLComponent for executing tasks\n", + "- Produces `Shareable` from input `Shareable` and handles `DXO` object conversion for standardized data passing\n", + "- Directly uses FLARE-specific communication concepts, and as such serves as the basis of higher level learning APIs made to abstract these concepts away\n", + "\n", + "See the [documentation](https://nvflare.readthedocs.io/en/main/programming_guide/executor.html#executor) for more information about Executors and other FLARE-specific constructs.\n", + "\n", + "### When to use Executors\n", + "\n", + "The Executor is best used when implementing tasks and logic that do not fit the standard learning methods of higher level APIs such as the ModelLearner or Client API. In this example, in addition to the `train`, `validate`, and `submit_model` tasks, we also introduce the `get_weights` task. This pretrain task allows us to perform the `InitializeGlobalWeights` workflow, which would otherwise not be supported.\n", + "\n", + "## Converting DL training code to FL Executor training code\n", + "We will use the original [Training a Classifer](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) example\n", + "in PyTorch as our base [DL code](../code/dl/train.py).\n", + "\n", + "In order to transform the existing PyTorch classifier training code into Federated Classifer training code, we must restructure our code to implement tasks to execute, as well as handle the data exchange formats. The converted code can be found at [FL Executor code](../code/fl/executor.py).\n", + "\n", + "Key changes:\n", + "- Encapsulate the original DL train and validate code inside `local_train()` and `local_validate()` and the dataset and PyTorch training utilities in `initialize()`\n", + "- Implement `execute` function to handle `get_weights`, `train`, `validate`, and `submit_model` tasks\n", + "- Process incoming and outgoing `Shareable` objects, and converting to and from `DXO` objects\n", + "- Implement `_save_local_model()` and `_load_local_model()` using the `PTPersistenceManager` to handle `ModelLearnable` object and manage the format for PyTorch model persistence.\n", + "\n", + "```\n", + "def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:\n", + " try:\n", + " if task_name == self.pre_train_task_name:\n", + " # Get the new state dict and send as weights\n", + " return self._get_model_weights()\n", + " if task_name == self.train_task_name:\n", + " # Get model weights\n", + " try:\n", + " dxo = from_shareable(shareable)\n", + " except:\n", + " self.log_error(fl_ctx, \"Unable to extract dxo from shareable.\")\n", + " return make_reply(ReturnCode.BAD_TASK_DATA)\n", + "\n", + " # Ensure data kind is weights.\n", + " if not dxo.data_kind == DataKind.WEIGHTS:\n", + " self.log_error(fl_ctx, f\"data_kind expected WEIGHTS but got {dxo.data_kind} instead.\")\n", + " return make_reply(ReturnCode.BAD_TASK_DATA)\n", + "\n", + " # Convert weights to tensor. Run training\n", + " torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()}\n", + " self._local_train(fl_ctx, torch_weights)\n", + "\n", + " # Check the abort_signal after training.\n", + " if abort_signal.triggered:\n", + " return make_reply(ReturnCode.TASK_ABORTED)\n", + "\n", + " # Save the local model after training.\n", + " self._save_local_model(fl_ctx)\n", + "\n", + " # Get the new state dict and send as weights\n", + " return self._get_model_weights()\n", + " if task_name == self.validate_task_name:\n", + " model_owner = \"?\"\n", + " try:\n", + " try:\n", + " dxo = from_shareable(shareable)\n", + " except:\n", + " self.log_error(fl_ctx, \"Error in extracting dxo from shareable.\")\n", + " return make_reply(ReturnCode.BAD_TASK_DATA)\n", + "\n", + " # Ensure data_kind is weights.\n", + " if not dxo.data_kind == DataKind.WEIGHTS:\n", + " self.log_exception(fl_ctx, f\"DXO is of type {dxo.data_kind} but expected type WEIGHTS.\")\n", + " return make_reply(ReturnCode.BAD_TASK_DATA)\n", + "\n", + " # Extract weights and ensure they are tensor.\n", + " model_owner = shareable.get_header(AppConstants.MODEL_OWNER, \"?\")\n", + " weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()}\n", + "\n", + " # Get validation accuracy\n", + " val_accuracy = self._local_validate(fl_ctx, weights)\n", + " if abort_signal.triggered:\n", + " return make_reply(ReturnCode.TASK_ABORTED)\n", + "\n", + " self.log_info(\n", + " fl_ctx,\n", + " f\"Accuracy when validating {model_owner}'s model on\"\n", + " f\" {fl_ctx.get_identity_name()}\"\n", + " f\"s data: {val_accuracy}\",\n", + " )\n", + "\n", + " dxo = DXO(data_kind=DataKind.METRICS, data={\"val_acc\": val_accuracy})\n", + " return dxo.to_shareable()\n", + " except:\n", + " self.log_exception(fl_ctx, f\"Exception in validating model from {model_owner}\")\n", + " return make_reply(ReturnCode.EXECUTION_EXCEPTION)\n", + " elif task_name == self.submit_model_task_name:\n", + " # Load local model\n", + " ml = self._load_local_model(fl_ctx)\n", + "\n", + " # Get the model parameters and create dxo from it\n", + " dxo = model_learnable_to_dxo(ml)\n", + " return dxo.to_shareable()\n", + " else:\n", + " return make_reply(ReturnCode.TASK_UNKNOWN)\n", + " except Exception as e:\n", + " self.log_exception(fl_ctx, f\"Exception in simple trainer: {e}.\")\n", + " return make_reply(ReturnCode.EXECUTION_EXCEPTION)\n", + "...\n", + "```\n", + "\n", + "## Job Configuration\n", + "\n", + "Now we must install the Executor to the training client. We define our CIFAR10Executor in the client configuration, and list the implemented tasks.\n", + "\n", + "Since our CIFAR10Executor supports the get_weights, train, validate, and submit_model tasks, we can use the InitializeGlobalWeights, CrossSiteModelEval, and ScatterAndGather workflows in the server configuration.\n", + "\n", + "Let's use the Job CLI to create the job from an PyTorch Executor template:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de430380", + "metadata": {}, + "outputs": [], + "source": [ + "! nvflare job create -j /tmp/nvflare/jobs/sag_pt_executor -w sag_pt_executor -sd ../code/fl -force" + ] + }, + { + "cell_type": "markdown", + "id": "5fd8e88f", + "metadata": {}, + "source": [ + "We can take a look at the server and client configurations and make any changes as desired:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "369c5501", + "metadata": {}, + "outputs": [], + "source": [ + "! cat /tmp/nvflare/jobs/sag_pt_executor/app/config/config_fed_server.conf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d223847b", + "metadata": {}, + "outputs": [], + "source": [ + "! cat /tmp/nvflare/jobs/sag_pt_executor/app/config/config_fed_client.conf" + ] + }, + { + "cell_type": "markdown", + "id": "83cc8869", + "metadata": {}, + "source": [ + "## Prepare Data" + ] + }, + { + "cell_type": "markdown", + "id": "8f63bf0f", + "metadata": {}, + "source": [ + "Make sure the CIFAR10 dataset is downloaded with the following script:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17323f61", + "metadata": {}, + "outputs": [], + "source": [ + "! python ../data/download.py" + ] + }, + { + "cell_type": "markdown", + "id": "d71f3c9f-8185-47d3-8658-40f7b16699c5", + "metadata": {}, + "source": [ + "## Run the Job\n", + "\n", + "Now we can run the job with the simulator:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70738539-3df6-4779-831f-0a1375d6aabf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "! nvflare simulator /tmp/nvflare/jobs/sag_pt_executor -w /tmp/nvflare/sag_pt_executor -t 2 -n 2 " + ] + }, + { + "cell_type": "markdown", + "id": "48271064", + "metadata": {}, + "source": [ + "For additional resources, take a look at the various other executors with different use cases in the app_common, app_opt, and examples folder." + ] + }, + { + "cell_type": "markdown", + "id": "9bef3134", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb b/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb new file mode 100644 index 0000000000..30db86c890 --- /dev/null +++ b/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb @@ -0,0 +1,238 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "514c47e2-420d-4af4-9bf0-cac337c51c39", + "metadata": {}, + "source": [ + "# FedAvg with SAG workflow using Model Learner\n", + "\n", + "In this example, we will demonstrate the FegAvg SAG workflow using the CIFAR10 dataset using the ModelLearner API. \n", + "\n", + "While the previous example [FedAvg with SAG workflow](../sag/sag.ipynb#title) utilized the Client API, here we will demonstrate how to convert the original training code into a ModelLearner trainer, showcase its capabilities, and recommend the best use cases.\n", + "\n", + "For an overview on Federated Averaging and SAG, see the section from the previous example: [Understanding FedAvg and SAG](../sag/sag.ipynb#sag)\n", + "\n", + "## ModelLearner\n", + "\n", + "The main goal of the ModelLearner is to make it easier to write learning logic by minimizing FLARE specific concepts that the user is exposed to. The ModelLearner defines familiar learning functions for training and validation, and uses the FLModel object for transferring learning information.\n", + "\n", + "Key Concepts:\n", + "- Learning\n", + " - `FLModel` object defines structure to containe essential information about the learning task, such as `params`, `metrics`, `meta`, etc.\n", + " - learning logic implemented in `train()` and `validate` methods, which both receive and send an `FLModel` object\n", + " - return requested model via `get_model()`\n", + "- Lifecycle\n", + " - `initialize` for logic before learning job start and `finalize` for once learning job is finished\n", + " - abort gracefully with `abort()` or `is_aborted()`\n", + "- Convenience \n", + " - various logging methods such as `info`, `debug`, `error`, etc.\n", + " - contextual information availabled in learner\n", + "\n", + "\n", + "Here are the full definitions of the APIs for the [ModelLearner](https://github.com/NVIDIA/NVFlare/blob/dev/nvflare/app_common/abstract/model_learner.py) and [FLModel](https://github.com/NVIDIA/NVFlare/blob/dev/nvflare/app_common/abstract/fl_model.py).\n", + "\n", + "### When to use ModelLearner\n", + "\n", + "The ModelLearner is best used when working with standard machine learning code that can fit well into the train and validate methods and can be easily adapated to the ModelLearner structure. This allows for the separation of FLARE specific communication constructs from the machine learning specific tasks, and provides the FLModel object for data transfer. \n", + "\n", + "On the otherhand, if the user would rather not adapt the code structure, we recommend using the [Client API](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/ml-to-fl/README.md) for even simpler conversion to FL code at the cost of losing some convenience functionalities.\n", + "\n", + "Finally, if the user wishes to implement something more specific that is not supported by the ModelLearner, we recommend writing an Executor which gives greater freedom for defining logic and tasks. The main tradeoff is this requires the use of more FLARE concepts such as FLContext, Shareable, DXO, etc.\n", + "\n", + "\n", + "## Converting DL training code to FL ModelLearner training code\n", + "We will use the original [Training a Classifer](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) example\n", + "in PyTorch as our base [DL code](../code/dl/train.py).\n", + "\n", + "With the FLARE ModelLearner API, we need to transform the existing PyTorch classifer into a Federated classifer by restructuring our code to subclass ModelLearner, and implementing the required methods. The converted code can be found at [FL ModelLearner code](../code/fl/model_learner.py).\n", + "\n", + "Key Changes:\n", + "- Subclass ModelLearner with appropriate init args\n", + "- Encapsulate the original DL train and validate code inside `local_train()` and `local_validate()` and the dataset and PyTorch training utilities in `initialize()`\n", + "- Implement the `train()` and `validate()` methods by wrapping the local learning methods and processing and returning `FLModel`\n", + "- Implement `get_model()` method to load and return best local model, so it can then be sent to other sites for validation (via the cross-site evaluation workflow)\n", + "\n", + "```\n", + "def get_model(self, model_name: str) -> Union[str, FLModel]:\n", + " # Retrieve the best local model saved during training.\n", + " if model_name == ModelName.BEST_MODEL:\n", + " try:\n", + " model_data = torch.load(self.model_path, map_location=\"cpu\")\n", + " np_model_data = {k: v.cpu().numpy() for k, v in model_data.items()}\n", + "\n", + " return FLModel(params_type=ParamsType.FULL, params=np_model_data)\n", + " except Exception as e:\n", + " raise ValueError(\"Unable to load best model\") from e\n", + " else:\n", + " raise ValueError(f\"Unknown model_type: {model_name}\") # Raised errors are caught in LearnerExecutor class.\n", + "\n", + "def train(self, model: FLModel) -> Union[str, FLModel]:\n", + " self.info(f\"Current/Total Round: {self.current_round + 1}/{self.total_rounds}\")\n", + " self.info(f\"Client identity: {self.site_name}\")\n", + "\n", + " pt_input_params = {k: torch.as_tensor(v) for k, v in model.params.items()}\n", + " self._local_train(pt_input_params)\n", + "\n", + " pt_output_params = {k: torch.as_tensor(v) for k, v in self.net.cpu().state_dict().items()}\n", + " accuracy = self._local_validate(pt_output_params)\n", + "\n", + " if accuracy > self.best_acc:\n", + " self.best_acc = accuracy\n", + " torch.save(self.net.state_dict(), self.model_path)\n", + "\n", + " np_output_params = {k: v.cpu().numpy() for k, v in self.net.cpu().state_dict().items()}\n", + " return FLModel(\n", + " params=np_output_params,\n", + " metrics={\"accuracy\": accuracy},\n", + " meta={\"NUM_STEPS_CURRENT_ROUND\": 2 * len(self.trainloader)},\n", + " )\n", + "\n", + "def validate(self, model: FLModel) -> Union[str, FLModel]:\n", + " pt_params = {k: torch.as_tensor(v) for k, v in model.params.items()}\n", + " val_accuracy = self._local_validate(pt_params)\n", + "\n", + " return FLModel(metrics={\"val_accuracy\": val_accuracy})\n", + "\n", + "...\n", + " \n", + "```\n", + "\n", + "## Job Configuration\n", + "\n", + "Now we must install the ModelLearner to the training client. We use the predefined `ModelLearnerExecutor`, which handles setting up the Learner and executing the tasks using the ModelLearner methods. In the client configuration, the `learner_id` of the `ModelLearnerExecutor` is mapped to the `id` of the ModelLearner trainer component that we implemented.\n", + "\n", + "Let's use the Job CLI to create the job from a ModelLearner template:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de430380", + "metadata": {}, + "outputs": [], + "source": [ + "! nvflare job create -j /tmp/nvflare/jobs/sag_pt_model_learner -w sag_pt_model_learner -sd ../code/fl -force" + ] + }, + { + "cell_type": "markdown", + "id": "5fd8e88f", + "metadata": {}, + "source": [ + "We can take a look at the server and client configurations and make any changes as desired:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "369c5501", + "metadata": {}, + "outputs": [], + "source": [ + "! cat /tmp/nvflare/jobs/sag_pt_model_learner/app/config/config_fed_server.conf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d223847b", + "metadata": {}, + "outputs": [], + "source": [ + "! cat /tmp/nvflare/jobs/sag_pt_model_learner/app/config/config_fed_client.conf" + ] + }, + { + "cell_type": "markdown", + "id": "edf82dac", + "metadata": {}, + "source": [ + "Ensure that our ModelLearner trainer code is correctly installed with the ModelLearnerExecutor. Also since the ModelLearnerExecutor supports the train, validate, and submit_model tasks, we can use the CrossSiteModelEval workflow in the server configuration in addition to the ScatterAndGather workflow." + ] + }, + { + "cell_type": "markdown", + "id": "83cc8869", + "metadata": {}, + "source": [ + "## Prepare Data" + ] + }, + { + "cell_type": "markdown", + "id": "8f63bf0f", + "metadata": {}, + "source": [ + "Make sure the CIFAR10 dataset is downloaded with the following script:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17323f61", + "metadata": {}, + "outputs": [], + "source": [ + "! python ../data/download.py" + ] + }, + { + "cell_type": "markdown", + "id": "d71f3c9f-8185-47d3-8658-40f7b16699c5", + "metadata": {}, + "source": [ + "## Run the Job\n", + "\n", + "Now we can run the job with the simulator:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70738539-3df6-4779-831f-0a1375d6aabf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "! nvflare simulator /tmp/nvflare/jobs/sag_pt_model_learner -w /tmp/nvflare/sag_pt_model_learner -t 2 -n 2 " + ] + }, + { + "cell_type": "markdown", + "id": "48271064", + "metadata": {}, + "source": [ + "As an additional resource, also see the [CIFAR10 examples](../../../../advanced/cifar10/README.md) for a comprehensive implementation of a PyTorch ModelLearner." + ] + }, + { + "cell_type": "markdown", + "id": "9bef3134", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/job_templates/sag_pt_executor/config_fed_client.conf b/job_templates/sag_pt_executor/config_fed_client.conf new file mode 100644 index 0000000000..6de432a017 --- /dev/null +++ b/job_templates/sag_pt_executor/config_fed_client.conf @@ -0,0 +1,26 @@ +format_version = 2 +executors = [ + { + # tasks that the defined Executor will support + tasks = [ + # pre-train task + "get_weights", + # training task + "train", + # cross-site validation tasks + "submit_model", + "validate", + ] + executor { + path = "executor.CIFAR10Executor" + args { + # see class docstring for all available args + epochs = 2 + lr = 0.001 + } + } + } +] +task_data_filters = [] +task_result_filters = [] +components = [] \ No newline at end of file diff --git a/job_templates/sag_pt_executor/config_fed_server.conf b/job_templates/sag_pt_executor/config_fed_server.conf new file mode 100644 index 0000000000..402c4d8501 --- /dev/null +++ b/job_templates/sag_pt_executor/config_fed_server.conf @@ -0,0 +1,83 @@ +format_version = 2 +server { + heart_beat_timeout = 600 +} +task_data_filters = [] +task_result_filters = [] +components = [ + { + id = "persistor" + name = "PTFileModelPersistor" + args { + model { + # path to defined PyTorch network + path = "net.Net" + args {} + } + } + } + { + id = "shareable_generator" + name = "FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + name = "InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHTS" + } + } + { + id = "model_selector" + name = "IntimeModelSelector" + args {} + } + { + id = "model_locator" + name = "PTFileModelLocator" + args { + pt_persistor_id = "persistor" + } + } + { + id = "json_generator" + name = "ValidationJsonGenerator" + args {} + } +] +workflows = [ + { + id = "pre_train" + name = "InitializeGlobalWeights" + args { + task_name = "get_weights" + } + } + { + id = "scatter_gather_ctl" + name = "ScatterAndGather" + args { + min_clients = 2 + # can adjust number of Scatter-And-Gather rounds + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 10 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + { + id = "cross_site_model_eval" + name = "CrossSiteModelEval" + args { + model_locator_id = "model_locator" + submit_model_timeout = 600 + validation_timeout = 6000 + cleanup_models = true + } + } +] diff --git a/job_templates/sag_pt_executor/info.conf b/job_templates/sag_pt_executor/info.conf new file mode 100644 index 0000000000..e9a2aac332 --- /dev/null +++ b/job_templates/sag_pt_executor/info.conf @@ -0,0 +1,5 @@ +{ + description = "scatter & gather workflow and cross-site evaluation with PyTorch Executor" + client_category = "Executor" + controller_type = "server" +} \ No newline at end of file diff --git a/job_templates/sag_pt_executor/info.md b/job_templates/sag_pt_executor/info.md new file mode 100644 index 0000000000..ebd0d1f25b --- /dev/null +++ b/job_templates/sag_pt_executor/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## sag_pt_executor + name = "sag_pt_executor" + description = "scatter & gather workflow and cross-site evaluation with PyTorch Executor" + class_name = "CIFAR10Executor" + controller_type = "server" + executor_type = "Executor" + contributor = "NVIDIA" + init_publish_date = "2023-09-29" + last_updated_date = "2023-09-29" diff --git a/job_templates/sag_pt_executor/meta.json b/job_templates/sag_pt_executor/meta.json new file mode 100644 index 0000000000..742ab6f717 --- /dev/null +++ b/job_templates/sag_pt_executor/meta.json @@ -0,0 +1,10 @@ +{ + "name": "sag_pt_executor", + "resource_spec": {}, + "min_clients": 2, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/job_templates/sag_pt_model_learner/config_fed_client.conf b/job_templates/sag_pt_model_learner/config_fed_client.conf new file mode 100644 index 0000000000..85d3915a83 --- /dev/null +++ b/job_templates/sag_pt_model_learner/config_fed_client.conf @@ -0,0 +1,34 @@ +format_version = 2 +executors = [ + { + # tasks that the defined Executor will support + tasks = [ + # training task + "train" + # cross-site validation tasks + "submit_model" + "validate" + ] + executor { + id = "Executor" + path = "nvflare.app_common.executors.model_learner_executor.ModelLearnerExecutor" + args { + # id must match the id of the ModelLearner component + learner_id = "cifar10-learner" + } + } + } +] +task_result_filters = [] +task_data_filters = [] +components = [ + { + id = "cifar10-learner" + path = "model_learner.CIFAR10ModelLearner" + args { + # see class docstring for all available args + epochs = 2 + lr = 0.001 + } + } +] diff --git a/job_templates/sag_pt_model_learner/config_fed_server.conf b/job_templates/sag_pt_model_learner/config_fed_server.conf new file mode 100644 index 0000000000..b144ddca99 --- /dev/null +++ b/job_templates/sag_pt_model_learner/config_fed_server.conf @@ -0,0 +1,76 @@ +format_version = 2 +server { + heart_beat_timeout = 600 +} +task_data_filters = [] +task_result_filters = [] +components = [ + { + id = "persistor" + name = "PTFileModelPersistor" + args { + model { + # path to defined PyTorch network + path = "net.Net" + args {} + } + } + } + { + id = "shareable_generator" + name = "FullModelShareableGenerator" + args {} + } + { + id = "aggregator" + name = "InTimeAccumulateWeightedAggregator" + args { + expected_data_kind = "WEIGHTS" + } + } + { + id = "model_selector" + name = "IntimeModelSelector" + args {} + } + { + id = "model_locator" + name = "PTFileModelLocator" + args { + pt_persistor_id = "persistor" + } + } + { + id = "json_generator" + name = "ValidationJsonGenerator" + args {} + } +] +workflows = [ + { + id = "scatter_gather_ctl" + name = "ScatterAndGather" + args { + min_clients = 2 + # can adjust number of Scatter-And-Gather rounds + num_rounds = 2 + start_round = 0 + wait_time_after_min_received = 10 + aggregator_id = "aggregator" + persistor_id = "persistor" + shareable_generator_id = "shareable_generator" + train_task_name = "train" + train_timeout = 0 + } + } + { + id = "cross_site_model_eval" + name = "CrossSiteModelEval" + args { + model_locator_id = "model_locator" + submit_model_timeout = 600 + validation_timeout = 6000 + cleanup_models = true + } + } +] diff --git a/job_templates/sag_pt_model_learner/info.conf b/job_templates/sag_pt_model_learner/info.conf new file mode 100644 index 0000000000..c2b58b68dc --- /dev/null +++ b/job_templates/sag_pt_model_learner/info.conf @@ -0,0 +1,5 @@ +{ + description = "scatter & gather workflow and cross-site evaluation with PyTorch ModelLearner" + client_category = "ModelLearner" + controller_type = "server" +} \ No newline at end of file diff --git a/job_templates/sag_pt_model_learner/info.md b/job_templates/sag_pt_model_learner/info.md new file mode 100644 index 0000000000..bd6fb31168 --- /dev/null +++ b/job_templates/sag_pt_model_learner/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## sag_pt_model_learner + name = "sag_pt_model_learner" + description = "scatter & gather workflow and cross-site evaluation with PyTorch ModelLearner" + class_name = "CIFAR10ModelLearner" + controller_type = "server" + executor_type = "ModelLearner" + contributor = "NVIDIA" + init_publish_date = "2023-09-29" + last_updated_date = "2023-09-29" diff --git a/job_templates/sag_pt_model_learner/meta.json b/job_templates/sag_pt_model_learner/meta.json new file mode 100644 index 0000000000..afbc2ef140 --- /dev/null +++ b/job_templates/sag_pt_model_learner/meta.json @@ -0,0 +1,10 @@ +{ + "name": "sag_pt_model_learner", + "resource_spec": {}, + "min_clients": 2, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index ddea05cd24..2ecb828f22 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -181,6 +181,8 @@ class ReservedTopic(object): DO_TASK = "__do_task__" AUX_COMMAND = "__aux_command__" SYNC_RUNNER = "__sync_runner__" + JOB_HEART_BEAT = "__job_heartbeat__" + TASK_CHECK = "__task_check__" class AdminCommandNames(object): diff --git a/nvflare/apis/impl/controller.py b/nvflare/apis/impl/controller.py index 16725f7ad0..0d8ca2f88b 100644 --- a/nvflare/apis/impl/controller.py +++ b/nvflare/apis/impl/controller.py @@ -344,6 +344,11 @@ def handle_dead_job(self, client_name: str, fl_ctx: FLContext): if not self._dead_client_reports.get(client_name): self._dead_client_reports[client_name] = time.time() + def process_task_check(self, task_id: str, fl_ctx: FLContext): + with self._task_lock: + # task_id is the uuid associated with the client_task + return self._client_task_map.get(task_id, None) + def process_submission(self, client: Client, task_name: str, task_id: str, result: Shareable, fl_ctx: FLContext): """Called to process a submission from one client. diff --git a/nvflare/apis/responder.py b/nvflare/apis/responder.py index a411b30bd4..6072670886 100644 --- a/nvflare/apis/responder.py +++ b/nvflare/apis/responder.py @@ -63,6 +63,16 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul """ pass + @abstractmethod + def process_task_check(self, task_id: str, fl_ctx: FLContext): + """Called by the Engine to check whether a specified task still exists. + Args: + task_id: the id of the task + fl_ctx: the FLContext + Returns: the ClientTask object if exists; None otherwise + """ + pass + @abstractmethod def handle_dead_job(self, client_name: str, fl_ctx: FLContext): """Called by the Engine to handle the case that the job on the client is dead. diff --git a/nvflare/fuel/f3/comm_config.py b/nvflare/fuel/f3/comm_config.py index c2fa51d9b5..4a93bd1d89 100644 --- a/nvflare/fuel/f3/comm_config.py +++ b/nvflare/fuel/f3/comm_config.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging - from nvflare.fuel.f3.drivers.net_utils import MAX_PAYLOAD_SIZE from nvflare.fuel.utils.config import Config from nvflare.fuel.utils.config_service import ConfigService @@ -34,6 +32,7 @@ class VarName: SUBNET_TROUBLE_THRESHOLD = "subnet_trouble_threshold" COMM_DRIVER_PATH = "comm_driver_path" HEARTBEAT_INTERVAL = "heartbeat_interval" + USE_AIO_GRPC_VAR_NAME = "use_aio_grpc" STREAMING_CHUNK_SIZE = "streaming_chunk_size" STREAMING_ACK_WAIT = "streaming_ack_wait" STREAMING_WINDOW_SIZE = "streaming_window_size" @@ -43,10 +42,26 @@ class VarName: class CommConfigurator: + + _config_loaded = False + _configuration = None + def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) - config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0]) - self.config = None if config is None else config.to_dict() + # only load once! + if not CommConfigurator._config_loaded: + config: Config = ConfigService.load_configuration(file_basename=_comm_config_files[0]) + CommConfigurator._configuration = None if config is None else config.to_dict() + CommConfigurator._config_loaded = True + self.config = CommConfigurator._configuration + + @staticmethod + def reset(): + """Reset the configurator to allow reloading config files. + + Returns: + + """ + CommConfigurator._config_loaded = False def get_config(self): return self.config @@ -78,6 +93,9 @@ def get_comm_driver_path(self, default): def get_heartbeat_interval(self, default): return ConfigService.get_int_var(VarName.HEARTBEAT_INTERVAL, self.config, default=default) + def use_aio_grpc(self, default): + return ConfigService.get_bool_var(VarName.USE_AIO_GRPC_VAR_NAME, self.config, default) + def get_streaming_chunk_size(self, default): return ConfigService.get_int_var(VarName.STREAMING_CHUNK_SIZE, self.config, default=default) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 2872e3b7b9..c93dffed93 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -34,6 +34,7 @@ from .base_driver import BaseDriver from .driver_params import DriverCap, DriverParams from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required GRPC_DEFAULT_OPTIONS = [ @@ -68,11 +69,18 @@ def __init__(self, aio_ctx: AioContext, connector: ConnectorInfo, conn_props: di def get_conn_properties(self) -> dict: return self.conn_props + async def _abort(self): + try: + self.context.abort(grpc.StatusCode.CANCELLED, "service closed") + except: + # ignore exception (if any) when aborting + pass + def close(self): self.closing = True with self.lock: if self.context: - self.aio_ctx.run_coro(self.context.abort(grpc.StatusCode.CANCELLED, "service closed")) + self.aio_ctx.run_coro(self._abort()) self.context = None if self.channel: self.aio_ctx.run_coro(self.channel.close()) @@ -197,20 +205,18 @@ def __init__(self, driver, connector, aio_ctx: AioContext, options, conn_ctx: _C servicer = Servicer(self, aio_ctx) add_StreamerServicer_to_server(servicer, self.grpc_server) params = connector.params - host = params.get(DriverParams.HOST.value) - if not host: - host = "0.0.0.0" - port = int(params.get(DriverParams.PORT.value)) - addr = f"{host}:{port}" + addr = get_address(params) try: self.logger.debug(f"SERVER: connector params: {params}") secure = ssl_required(params) if secure: - credentials = AioGrpcDriver.get_grpc_server_credentials(params) + credentials = get_grpc_server_credentials(params) self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") else: self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") except Exception as ex: conn_ctx.error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" self.logger.debug(conn_ctx.error) @@ -251,7 +257,10 @@ def __init__(self): @staticmethod def supported_transports() -> List[str]: - return ["grpc", "grpcs"] + if use_aio_grpc(): + return ["grpc", "grpcs"] + else: + return ["agrpc", "agrpcs"] @staticmethod def capabilities() -> Dict[str, Any]: @@ -295,10 +304,12 @@ async def _start_connect(self, connector: ConnectorInfo, aio_ctx: AioContext, co secure = ssl_required(params) if secure: grpc_channel = grpc.aio.secure_channel( - address, options=self.options, credentials=self.get_grpc_client_credentials(params) + address, options=self.options, credentials=get_grpc_client_credentials(params) ) + self.logger.info(f"created secure channel at {address}") else: grpc_channel = grpc.aio.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") async with grpc_channel as channel: self.logger.debug(f"CLIENT: connected to {address}") @@ -374,38 +385,9 @@ def shutdown(self): def get_urls(scheme: str, resources: dict) -> (str, str): secure = resources.get(DriverParams.SECURE) if secure: - scheme = "grpcs" + if use_aio_grpc(): + scheme = "grpcs" + else: + scheme = "agrpcs" return get_tcp_urls(scheme, resources) - - @staticmethod - def get_grpc_client_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.CLIENT_KEY)) - - return grpc.ssl_channel_credentials( - certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert - ) - - @staticmethod - def get_grpc_server_credentials(params: dict): - - root_cert = AioGrpcDriver.read_file(params.get(DriverParams.CA_CERT.value)) - cert_chain = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_CERT)) - private_key = AioGrpcDriver.read_file(params.get(DriverParams.SERVER_KEY)) - - return grpc.ssl_server_credentials( - [(private_key, cert_chain)], - root_certificates=root_cert, - require_client_auth=True, - ) - - @staticmethod - def read_file(file_name: str): - if not file_name: - return None - - with open(file_name, "rb") as f: - return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc/qq.py b/nvflare/fuel/f3/drivers/grpc/qq.py new file mode 100644 index 0000000000..ca0eeb25f2 --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/qq.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import queue + + +class QueueClosed(Exception): + pass + + +class QQ: + def __init__(self): + self.q = queue.Queue() + self.closed = False + self.logger = logging.getLogger(self.__class__.__name__) + + def close(self): + self.closed = True + + def append(self, i): + if self.closed: + raise QueueClosed("queue stopped") + self.q.put_nowait(i) + + def __iter__(self): + return self + + def __next__(self): + if self.closed: + raise StopIteration() + while True: + try: + return self.q.get(block=True, timeout=0.1) + except queue.Empty: + if self.closed: + self.logger.debug("Queue closed - stop iteration") + raise StopIteration() + except Exception as e: + self.logger.error(f"queue exception {type(e)}") + raise e diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py new file mode 100644 index 0000000000..d95bb8138a --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def use_aio_grpc(): + configurator = CommConfigurator() + return configurator.use_aio_grpc(default=True) + + +def get_grpc_client_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.CLIENT_CERT)) + private_key = _read_file(params.get(DriverParams.CLIENT_KEY)) + return grpc.ssl_channel_credentials( + certificate_chain=cert_chain, private_key=private_key, root_certificates=root_cert + ) + + +def get_grpc_server_credentials(params: dict): + root_cert = _read_file(params.get(DriverParams.CA_CERT.value)) + cert_chain = _read_file(params.get(DriverParams.SERVER_CERT)) + private_key = _read_file(params.get(DriverParams.SERVER_KEY)) + + return grpc.ssl_server_credentials( + [(private_key, cert_chain)], + root_certificates=root_cert, + require_client_auth=True, + ) + + +def _read_file(file_name: str): + if not file_name: + return None + + with open(file_name, "rb") as f: + return f.read() diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py new file mode 100644 index 0000000000..6ef5711aec --- /dev/null +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -0,0 +1,288 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +from concurrent import futures +from typing import Any, Dict, List, Union + +import grpc + +from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_error import CommError +from nvflare.fuel.f3.connection import Connection +from nvflare.fuel.f3.drivers.driver import ConnectorInfo +from nvflare.fuel.f3.drivers.grpc.streamer_pb2_grpc import ( + StreamerServicer, + StreamerStub, + add_StreamerServicer_to_server, +) +from nvflare.fuel.utils.obj_utils import get_logger +from nvflare.security.logging import secure_format_exception + +from .base_driver import BaseDriver +from .driver_params import DriverCap, DriverParams +from .grpc.qq import QQ +from .grpc.streamer_pb2 import Frame +from .grpc.utils import get_grpc_client_credentials, get_grpc_server_credentials, use_aio_grpc +from .net_utils import MAX_FRAME_SIZE, get_address, get_tcp_urls, ssl_required + +GRPC_DEFAULT_OPTIONS = [ + ("grpc.max_send_message_length", MAX_FRAME_SIZE), + ("grpc.max_receive_message_length", MAX_FRAME_SIZE), +] + + +class StreamConnection(Connection): + + seq_num = 0 + + def __init__(self, oq: QQ, connector: ConnectorInfo, conn_props: dict, side: str, context=None, channel=None): + super().__init__(connector) + self.side = side + self.oq = oq + self.closing = False + self.conn_props = conn_props + self.context = context # for server side + self.channel = channel # for client side + self.lock = threading.Lock() + self.logger = get_logger(self) + + def get_conn_properties(self) -> dict: + return self.conn_props + + def close(self): + self.closing = True + with self.lock: + self.oq.close() + if self.context: + try: + self.context.abort(grpc.StatusCode.CANCELLED, "service closed") + except: + # ignore any exception when aborting + pass + self.context = None + if self.channel: + self.channel.close() + self.channel = None + + def send_frame(self, frame: Union[bytes, bytearray, memoryview]): + try: + StreamConnection.seq_num += 1 + seq = StreamConnection.seq_num + self.logger.debug(f"{self.side}: queued frame #{seq}") + self.oq.append(Frame(seq=seq, data=bytes(frame))) + except BaseException as ex: + raise CommError(CommError.ERROR, f"Error sending frame: {ex}") + + def read_loop(self, msg_iter, q: QQ): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: started read_loop in thread {ct.name}") + try: + for f in msg_iter: + if self.closing: + break + + assert isinstance(f, Frame) + self.logger.debug(f"{self.side} in {ct.name}: incoming frame #{f.seq}") + if self.frame_receiver: + self.frame_receiver.process_frame(f.data) + else: + self.logger.error(f"{self.side}: Frame receiver not registered for connection: {self.name}") + except Exception as ex: + if not self.closing: + self.logger.debug(f"{self.side}: exception {type(ex)} in read_loop") + if q: + self.logger.debug(f"{self.side}: closing queue") + q.close() + self.logger.debug(f"{self.side} in {ct.name}: done read_loop") + + def generate_output(self): + ct = threading.current_thread() + self.logger.debug(f"{self.side}: generate_output in thread {ct.name}") + for i in self.oq: + assert isinstance(i, Frame) + self.logger.debug(f"{self.side}: outgoing frame #{i.seq}") + yield i + self.logger.debug(f"{self.side}: done generate_output in thread {ct.name}") + + +class Servicer(StreamerServicer): + def __init__(self, server): + self.server = server + self.logger = get_logger(self) + + def Stream(self, request_iterator, context): + connection = None + oq = QQ() + t = None + ct = threading.current_thread() + conn_props = { + DriverParams.PEER_ADDR.value: context.peer(), + DriverParams.LOCAL_ADDR.value: get_address(self.server.connector.params), + } + cn_names = context.auth_context().get("x509_common_name") + if cn_names: + conn_props[DriverParams.PEER_CN.value] = cn_names[0].decode("utf-8") + + try: + self.logger.debug(f"SERVER started Stream CB in thread {ct.name}") + connection = StreamConnection(oq, self.server.connector, conn_props, "SERVER", context=context) + self.logger.debug(f"SERVER created connection in thread {ct.name}") + self.server.driver.add_connection(connection) + self.logger.debug(f"SERVER created read_loop thread in thread {ct.name}") + t = threading.Thread(target=connection.read_loop, args=(request_iterator, oq)) + t.start() + + # DO NOT use connection.generate_output()! + self.logger.debug(f"SERVER: generate_output in thread {ct.name}") + for i in oq: + assert isinstance(i, Frame) + self.logger.debug(f"SERVER: outgoing frame #{i.seq}") + yield i + self.logger.debug(f"SERVER: done generate_output in thread {ct.name}") + + except BaseException as ex: + self.logger.error(f"Connection closed due to error: {ex}") + finally: + if t is not None: + t.join() + if connection: + self.logger.debug(f"SERVER: closing connection {connection.name}") + self.server.driver.close_connection(connection) + self.logger.debug(f"SERVER: cleanly finished Stream CB in thread {ct.name}") + + +class Server: + def __init__( + self, + driver, + connector, + max_workers, + options, + ): + self.driver = driver + self.logger = get_logger(self) + self.connector = connector + self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=options) + servicer = Servicer(self) + add_StreamerServicer_to_server(servicer, self.grpc_server) + + params = connector.params + addr = get_address(params) + try: + self.logger.debug(f"SERVER: connector params: {params}") + secure = ssl_required(params) + if secure: + credentials = get_grpc_server_credentials(params) + self.grpc_server.add_secure_port(addr, server_credentials=credentials) + self.logger.info(f"added secure port at {addr}") + else: + self.grpc_server.add_insecure_port(addr) + self.logger.info(f"added insecure port at {addr}") + except Exception as ex: + error = f"cannot listen on {addr}: {type(ex)}: {secure_format_exception(ex)}" + self.logger.debug(error) + + def start(self): + self.grpc_server.start() + self.grpc_server.wait_for_termination() + + def shutdown(self): + self.grpc_server.stop(grace=0.5) + + +class GrpcDriver(BaseDriver): + def __init__(self): + BaseDriver.__init__(self) + self.server = None + self.closing = False + self.max_workers = 100 + self.options = GRPC_DEFAULT_OPTIONS + self.logger = get_logger(self) + configurator = CommConfigurator() + config = configurator.get_config() + if config: + my_params = config.get("grpc") + if my_params: + self.max_workers = my_params.get("max_workers", 100) + self.options = my_params.get("options") + self.logger.debug(f"GRPC Config: max_workers={self.max_workers}, options={self.options}") + + @staticmethod + def supported_transports() -> List[str]: + if use_aio_grpc(): + return ["nagrpc", "nagrpcs"] + else: + return ["grpc", "grpcs"] + + @staticmethod + def capabilities() -> Dict[str, Any]: + return {DriverCap.SEND_HEARTBEAT.value: True, DriverCap.SUPPORT_SSL.value: True} + + def listen(self, connector: ConnectorInfo): + self.connector = connector + self.server = Server(self, connector, max_workers=self.max_workers, options=self.options) + self.server.start() + + def connect(self, connector: ConnectorInfo): + self.logger.debug("CLIENT: trying connect ...") + params = connector.params + address = get_address(params) + conn_props = {DriverParams.PEER_ADDR.value: address} + + secure = ssl_required(params) + if secure: + self.logger.debug("CLIENT: creating secure channel") + channel = grpc.secure_channel( + address, options=self.options, credentials=get_grpc_client_credentials(params) + ) + self.logger.info(f"created secure channel at {address}") + else: + self.logger.info("CLIENT: creating insecure channel") + channel = grpc.insecure_channel(address, options=self.options) + self.logger.info(f"created insecure channel at {address}") + + self.logger.debug("CLIENT: created channel") + stub = StreamerStub(channel) + self.logger.debug("CLIENT: got stub") + oq = QQ() + connection = StreamConnection(oq, connector, conn_props, "CLIENT", channel=channel) + self.add_connection(connection) + self.logger.debug("CLIENT: added connection") + try: + received = stub.Stream(connection.generate_output()) + connection.read_loop(received, oq) + except BaseException as ex: + self.logger.info(f"CLIENT: connection done: {type(ex)}") + connection.close() + self.close_connection(connection) + self.logger.info(f"CLIENT: finished connection {connection}") + + @staticmethod + def get_urls(scheme: str, resources: dict) -> (str, str): + secure = resources.get(DriverParams.SECURE) + if secure: + if use_aio_grpc(): + scheme = "nagrpcs" + else: + scheme = "grpcs" + return get_tcp_urls(scheme, resources) + + def shutdown(self): + if self.closing: + return + self.closing = True + self.close_all() + if self.server: + self.server.shutdown() diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 3f9ae1e77d..6dd17b8979 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -79,7 +79,8 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: def get_address(params: dict) -> str: host = params.get(DriverParams.HOST.value, "0.0.0.0") port = params.get(DriverParams.PORT.value, 0) - + if not host: + host = "0.0.0.0" return f"{host}:{port}" diff --git a/nvflare/fuel/flare_api/flare_api.py b/nvflare/fuel/flare_api/flare_api.py index 4a6de97917..e3ba0f2b85 100644 --- a/nvflare/fuel/flare_api/flare_api.py +++ b/nvflare/fuel/flare_api/flare_api.py @@ -358,12 +358,8 @@ def download_job_result(self, job_id: str) -> str: self._validate_job_id(job_id) result = self._do_command(AdminCommandNames.DOWNLOAD_JOB + " " + job_id) meta = result[ResultKey.META] - download_job_id = meta.get(MetaKey.JOB_ID, None) - job_download_url = meta.get(MetaKey.JOB_DOWNLOAD_URL, None) - if not job_download_url: - return os.path.join(self.download_dir, download_job_id) - else: - return job_download_url + location = meta.get(MetaKey.LOCATION) + return location def abort_job(self, job_id: str): """Abort the specified job. diff --git a/nvflare/fuel/hci/client/cli.py b/nvflare/fuel/hci/client/cli.py index d43a9a7439..d18c54dcad 100644 --- a/nvflare/fuel/hci/client/cli.py +++ b/nvflare/fuel/hci/client/cli.py @@ -304,6 +304,13 @@ def default(self, line): self.write_stdout(f"exception occurred: {secure_format_exception(e)}") self._close_output_file() + @staticmethod + def _user_input(prompt: str) -> str: + answer = input(prompt) + + # remove leading and trailing spaces + return answer.strip() + def _do_default(self, line): args = split_to_args(line) cmd_name = args[0] @@ -360,14 +367,14 @@ def _do_default(self, line): info = CommandInfo.CONFIRM_YN if info == CommandInfo.CONFIRM_YN: - answer = input("Are you sure (y/N): ") + answer = self._user_input("Are you sure (y/N): ") answer = answer.lower() if answer != "y" and answer != "yes": return elif info == CommandInfo.CONFIRM_USER_NAME: - answer = input("Confirm with User Name: ") + answer = self._user_input("Confirm with User Name: ") if answer != self.user_name: - self.write_string("user name mismatch") + self.write_string(f"user name mismatch: {answer} != {self.user_name}") return elif info == CommandInfo.CONFIRM_PWD: pwd = getpass.getpass("Enter password to confirm: ") @@ -428,7 +435,7 @@ def cmdloop(self, intro=None): else: if self.use_rawinput: try: - line = input(self.prompt) + line = self._user_input(self.prompt) except (EOFError, ConnectionError): line = "bye" except KeyboardInterrupt: @@ -477,7 +484,7 @@ def _get_login_creds(self): elif self.credential_type == CredentialType.LOCAL_CERT: self.user_name = self.username else: - self.user_name = input("User Name: ") + self.user_name = self._user_input("User Name: ") def print_resp(self, resp: dict): """Prints the server response diff --git a/nvflare/fuel/hci/client/file_transfer.py b/nvflare/fuel/hci/client/file_transfer.py index 113e304eba..a31eb51ad7 100644 --- a/nvflare/fuel/hci/client/file_transfer.py +++ b/nvflare/fuel/hci/client/file_transfer.py @@ -430,7 +430,11 @@ def pull_folder(self, args, ctx: CommandContext): tx_path = self._tx_path(tx_id, folder_name) destination_path = os.path.join(self.download_dir, destination_name) location = self._rename_folder(tx_path, destination_path) - reply = {ProtoKey.STATUS: APIStatus.SUCCESS, ProtoKey.DETAILS: f"content downloaded to {location}"} + reply = { + ProtoKey.STATUS: APIStatus.SUCCESS, + ProtoKey.DETAILS: f"content downloaded to {location}", + ProtoKey.META: {MetaKey.LOCATION: location}, + } else: reply = error return reply diff --git a/nvflare/fuel/hci/proto.py b/nvflare/fuel/hci/proto.py index 2459181902..fb0e272b63 100644 --- a/nvflare/fuel/hci/proto.py +++ b/nvflare/fuel/hci/proto.py @@ -66,6 +66,7 @@ class MetaKey(object): CMD_NAME = "cmd_name" TX_ID = "tx_id" FOLDER_NAME = "folder_name" + LOCATION = "location" class MetaStatusValue(object): diff --git a/nvflare/lighter/impl/master_template.yml b/nvflare/lighter/impl/master_template.yml index 329b9c5805..3f58fa65f2 100644 --- a/nvflare/lighter/impl/master_template.yml +++ b/nvflare/lighter/impl/master_template.yml @@ -1157,7 +1157,7 @@ azure_start_svr_sh: | $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ sudo apt-get update && \ sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF + EOF ) az vm run-command invoke \ --output json \ @@ -1343,7 +1343,7 @@ azure_start_cln_sh: | $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ sudo apt-get update && \ sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF + EOF ) az vm run-command invoke \ --output json \ diff --git a/nvflare/private/fed/client/client_run_manager.py b/nvflare/private/fed/client/client_run_manager.py index d52fdf2583..6dcf435046 100644 --- a/nvflare/private/fed/client/client_run_manager.py +++ b/nvflare/private/fed/client/client_run_manager.py @@ -128,9 +128,10 @@ def new_context(self) -> FLContext: def send_task_result(self, result: Shareable, fl_ctx: FLContext) -> bool: push_result = self.client.push_results(result, fl_ctx) # push task execution results - if push_result[0] == CellReturnCode.OK: + if push_result == CellReturnCode.OK: return True else: + self.logger.error(f"failed to send task result: {push_result}") return False def get_workspace(self) -> Workspace: diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py index ab7543b55b..ab8d0d57d5 100644 --- a/nvflare/private/fed/client/client_runner.py +++ b/nvflare/private/fed/client/client_runner.py @@ -25,6 +25,7 @@ from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import add_job_audit_event from nvflare.apis.utils.task_utils import apply_filters +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.config_service import ConfigService from nvflare.private.defs import SpecialTaskName, TaskConstant from nvflare.private.fed.client.client_engine_executor_spec import ClientEngineExecutorSpec, TaskAssignment @@ -33,6 +34,10 @@ from nvflare.security.logging import secure_format_exception from nvflare.widgets.info_collector import GroupInfoCollector, InfoCollector +_TASK_CHECK_RESULT_OK = 0 +_TASK_CHECK_RESULT_TRY_AGAIN = 1 +_TASK_CHECK_RESULT_TASK_GONE = 2 + class TaskRouter: def __init__(self): @@ -136,6 +141,9 @@ def __init__( self.run_abort_signal = Signal() self.task_lock = threading.Lock() self.running_tasks = {} # task_id => TaskAssignment + + self.task_check_timeout = 5.0 + self.task_check_interval = 5.0 self._register_aux_message_handlers(engine) def find_executor(self, task_name): @@ -398,11 +406,39 @@ def _do_process_task(self, task: TaskAssignment, fl_ctx: FLContext, abort_signal return self._reply_and_audit(reply=reply, ref=server_audit_event_id, fl_ctx=fl_ctx, msg="submit result OK") def _try_run(self): + heartbeat_thread = threading.Thread(target=self._send_job_heartbeat, args=[], daemon=True) + heartbeat_thread.start() + while not self.run_abort_signal.triggered: with self.engine.new_context() as fl_ctx: task_fetch_interval, _ = self.fetch_and_run_one_task(fl_ctx) time.sleep(task_fetch_interval) + def _send_job_heartbeat(self, interval=30.0): + sleep_time = 1.0 + wait_times = int(interval / sleep_time) + if wait_times == 0: + wait_times = 1 + request = Shareable() + while not self.run_abort_signal.triggered: + with self.engine.new_context() as fl_ctx: + self.engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=ReservedTopic.JOB_HEART_BEAT, + request=request, + timeout=0, + fl_ctx=fl_ctx, + optional=True, + ) + + # we want to send the HB every "interval" secs. + # but we don't want to sleep that long since it will block us from checking abort signal. + # hence we only sleep 1 sec, and check the abort signal. + for i in range(wait_times): + time.sleep(sleep_time) + if self.run_abort_signal.triggered: + break + def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): """Fetches and runs a task. @@ -439,19 +475,105 @@ def fetch_and_run_one_task(self, fl_ctx) -> (float, bool): self.log_debug(fl_ctx, "firing event EventType.BEFORE_SEND_TASK_RESULT") self.fire_event(EventType.BEFORE_SEND_TASK_RESULT, fl_ctx) - reply_sent = self.engine.send_task_result(task_reply, fl_ctx) - if reply_sent: - self.log_info(fl_ctx, "result sent to server for task: name={}, id={}".format(task.name, task.task_id)) - else: - self.log_error( - fl_ctx, - "failed to send result to server for task: name={}, id={}".format(task.name, task.task_id), - ) + self._send_task_result(task_reply, task.task_id, fl_ctx) self.log_debug(fl_ctx, "firing event EventType.AFTER_SEND_TASK_RESULT") self.fire_event(EventType.AFTER_SEND_TASK_RESULT, fl_ctx) return task_fetch_interval, True + def _send_task_result(self, result: Shareable, task_id: str, fl_ctx: FLContext): + try_count = 1 + while True: + self.log_info(fl_ctx, f"try #{try_count}: sending task result to server") + + if self.run_abort_signal.triggered: + self.log_info(fl_ctx, "job aborted: stopped trying to send result") + return False + + try_count += 1 + rc = self._try_send_result_once(result, task_id, fl_ctx) + + if rc == _TASK_CHECK_RESULT_OK: + return True + elif rc == _TASK_CHECK_RESULT_TASK_GONE: + return False + else: + # retry + time.sleep(self.task_check_interval) + + def _try_send_result_once(self, result: Shareable, task_id: str, fl_ctx: FLContext): + # wait until server is ready to receive + while True: + if self.run_abort_signal.triggered: + return _TASK_CHECK_RESULT_TASK_GONE + + rc = self._check_task_once(task_id, fl_ctx) + if rc == _TASK_CHECK_RESULT_OK: + break + elif rc == _TASK_CHECK_RESULT_TASK_GONE: + return rc + else: + # try again + time.sleep(self.task_check_interval) + + # try to send the result + self.log_info(fl_ctx, "start to send task result to server") + reply_sent = self.engine.send_task_result(result, fl_ctx) + if reply_sent: + self.log_info(fl_ctx, "task result sent to server") + return _TASK_CHECK_RESULT_OK + else: + self.log_error(fl_ctx, "failed to send task result to server - will try again") + return _TASK_CHECK_RESULT_TRY_AGAIN + + def _check_task_once(self, task_id: str, fl_ctx: FLContext) -> int: + """This method checks whether the server is still waiting for the specified task. + The real reason for this method is to fight against unstable network connections. + We try to make sure that when we send task result to the server, the connection is available. + If the task check succeeds, then the network connection is likely to be available. + Otherwise, we keep retrying until task check succeeds or the server tells us that the task is gone (timed out). + Args: + task_id: + fl_ctx: + Returns: + """ + self.log_info(fl_ctx, "checking task ...") + task_check_req = Shareable() + task_check_req.set_header(ReservedKey.TASK_ID, task_id) + resp = self.engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=ReservedTopic.TASK_CHECK, + request=task_check_req, + timeout=self.task_check_timeout, + fl_ctx=fl_ctx, + optional=True, + ) + if resp and isinstance(resp, dict): + reply = resp.get(FQCN.ROOT_SERVER) + if not isinstance(reply, Shareable): + self.log_error(fl_ctx, f"bad task_check reply from server: expect Shareable but got {type(reply)}") + return _TASK_CHECK_RESULT_TRY_AGAIN + + rc = reply.get_return_code() + if rc == ReturnCode.OK: + return _TASK_CHECK_RESULT_OK + elif rc == ReturnCode.COMMUNICATION_ERROR: + self.log_error(fl_ctx, f"failed task_check: {rc}") + return _TASK_CHECK_RESULT_TRY_AGAIN + elif rc == ReturnCode.SERVER_NOT_READY: + self.log_error(fl_ctx, f"server rejected task_check: {rc}") + return _TASK_CHECK_RESULT_TRY_AGAIN + elif rc == ReturnCode.TASK_UNKNOWN: + self.log_error(fl_ctx, f"task no longer exists on server: {rc}") + return _TASK_CHECK_RESULT_TASK_GONE + else: + # this should never happen + self.log_error(fl_ctx, f"programming error: received {rc} from server") + return _TASK_CHECK_RESULT_OK # try to push the result regardless + else: + self.log_error(fl_ctx, f"bad task_check reply from server: invalid resp {type(resp)}") + return _TASK_CHECK_RESULT_TRY_AGAIN + def run(self, app_root, args): self.init_run(app_root, args) diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 0f0f44231b..4c8c96b21a 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -64,6 +64,7 @@ def __init__( cell: CoreCell = None, client_register_interval=2, timeout=5.0, + maint_msg_timeout=5.0, ): """To init the Communicator. @@ -84,6 +85,7 @@ def __init__( self.compression = compression self.client_register_interval = client_register_interval self.timeout = timeout + self.maint_msg_timeout = maint_msg_timeout self.logger = logging.getLogger(self.__class__.__name__) @@ -130,7 +132,7 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.Register, request=login_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: @@ -298,7 +300,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.Quit, request=quit_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: @@ -336,7 +338,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C channel=CellChannel.SERVER_MAIN, topic=CellChannelTopic.HEART_BEAT, request=heartbeat_message, - timeout=self.timeout, + timeout=self.maint_msg_timeout, ) return_code = result.get_header(MessageHeaderKey.RETURN_CODE) if return_code == ReturnCode.UNAUTHENTICATED: diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 9ff11a2ca7..b62642c8ee 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -15,8 +15,6 @@ import logging import threading import time -from functools import partial -from multiprocessing.dummy import Pool as ThreadPool from typing import List, Optional from nvflare.apis.filter import Filter @@ -40,15 +38,6 @@ from .communicator import Communicator -def _check_progress(remote_tasks): - if remote_tasks[0] is not None: - # shareable = fobs.loads(remote_tasks[0].payload) - shareable = remote_tasks[0].payload - return True, shareable.get_header(ServerCommandKey.TASK_NAME), shareable - else: - return False, None, None - - class FederatedClientBase: """The client-side base implementation of federated learning. @@ -104,6 +93,7 @@ def __init__( cell=cell, client_register_interval=client_args.get("client_register_interval", 2.0), timeout=client_args.get("communication_timeout", 30.0), + maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0), ) self.secure_train = secure_train @@ -336,63 +326,38 @@ def quit_remote(self, project_name, fl_ctx: FLContext): """ return self.communicator.quit_remote(self.servers, project_name, self.token, self.ssid, fl_ctx) + def _get_project_name(self): + """Get name of the project that the site is part of. + + Returns: + + """ + s = tuple(self.servers) # self.servers is a dict of project_name => server config + return s[0] + def heartbeat(self, interval): """Sends a heartbeat from the client to the server.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.send_heartbeat, interval=interval), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.send_heartbeat(self._get_project_name(), interval) def pull_task(self, fl_ctx: FLContext): """Fetch remote models and update the local client's session.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - self.remote_tasks = pool.map(partial(self.fetch_execute_task, fl_ctx=fl_ctx), tuple(self.servers)) - pull_success, task_name, shareable = _check_progress(self.remote_tasks) - # TODO: if some of the servers failed - return pull_success, task_name, shareable - finally: - if pool: - pool.terminate() + result = self.fetch_execute_task(self._get_project_name(), fl_ctx) + if result: + shareable = result.payload + return True, shareable.get_header(ServerCommandKey.TASK_NAME), shareable + else: + return False, None, None def push_results(self, shareable: Shareable, fl_ctx: FLContext): """Push the local model to multiple servers.""" - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.push_execute_result, shareable=shareable, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.push_execute_result(self._get_project_name(), shareable, fl_ctx) def register(self, fl_ctx: FLContext): - """Push the local model to multiple servers. - - Args: - fl_ctx: FLContext - - Returns: N/A - """ - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.client_register, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + """Push the local model to multiple servers.""" + return self.client_register(self._get_project_name(), fl_ctx) def set_primary_sp(self, sp): - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.set_sp, sp=sp), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.set_sp(self._get_project_name(), sp) def run_heartbeat(self, interval): """Periodically runs the heartbeat.""" @@ -403,6 +368,7 @@ def run_heartbeat(self, interval): def start_heartbeat(self, interval=30): heartbeat_thread = threading.Thread(target=self.run_heartbeat, args=[interval]) + heartbeat_thread.daemon = True heartbeat_thread.start() def logout_client(self, fl_ctx: FLContext): @@ -414,13 +380,7 @@ def logout_client(self, fl_ctx: FLContext): Returns: N/A """ - pool = None - try: - pool = ThreadPool(len(self.servers)) - return pool.map(partial(self.quit_remote, fl_ctx=fl_ctx), tuple(self.servers)) - finally: - if pool: - pool.terminate() + return self.quit_remote(self._get_project_name(), fl_ctx) def set_client_engine(self, engine): self.engine = engine diff --git a/nvflare/private/fed/server/server_runner.py b/nvflare/private/fed/server/server_runner.py index d9635eb6c1..5e782f45f2 100644 --- a/nvflare/private/fed/server/server_runner.py +++ b/nvflare/private/fed/server/server_runner.py @@ -99,11 +99,19 @@ def __init__(self, config: ServerRunnerConfig, job_id: str, engine: ServerEngine self.current_wf_index = 0 self.status = "init" self.turn_to_cold = False + self._register_aux_message_handler(engine) + def _register_aux_message_handler(self, engine): engine.register_aux_message_handler( topic=ReservedTopic.SYNC_RUNNER, message_handle_func=self._handle_sync_runner ) + engine.register_aux_message_handler( + topic=ReservedTopic.JOB_HEART_BEAT, message_handle_func=self._handle_job_heartbeat + ) + + engine.register_aux_message_handler(topic=ReservedTopic.TASK_CHECK, message_handle_func=self._handle_task_check) + def _handle_sync_runner(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: # simply ack return make_reply(ReturnCode.OK) @@ -475,6 +483,31 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul "Error processing client result by {}: {}".format(self.current_wf.id, secure_format_exception(e)), ) + def _handle_job_heartbeat(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + self.log_info(fl_ctx, "received client job_heartbeat aux request") + return make_reply(ReturnCode.OK) + + def _handle_task_check(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + task_id = request.get_header(ReservedHeaderKey.TASK_ID) + if not task_id: + self.log_error(fl_ctx, f"missing {ReservedHeaderKey.TASK_ID} in task_check request") + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + self.log_info(fl_ctx, f"received task_check on task {task_id}") + + with self.wf_lock: + if self.current_wf is None or self.current_wf.responder is None: + self.log_info(fl_ctx, "no current workflow - dropped task_check.") + return make_reply(ReturnCode.TASK_UNKNOWN) + + task = self.current_wf.responder.process_task_check(task_id=task_id, fl_ctx=fl_ctx) + if task: + self.log_info(fl_ctx, f"task {task_id} is still good") + return make_reply(ReturnCode.OK) + else: + self.log_info(fl_ctx, f"task {task_id} is not found") + return make_reply(ReturnCode.TASK_UNKNOWN) + def abort(self, fl_ctx: FLContext, turn_to_cold: bool = False): self.status = "done" self.abort_signal.trigger(value=True) diff --git a/nvflare/private/fed/utils/fed_utils.py b/nvflare/private/fed/utils/fed_utils.py index b6cdcc534f..ba39d26e23 100644 --- a/nvflare/private/fed/utils/fed_utils.py +++ b/nvflare/private/fed/utils/fed_utils.py @@ -18,7 +18,6 @@ import os import sys from logging.handlers import RotatingFileHandler -from multiprocessing.connection import Listener from typing import List from nvflare.apis.app_validation import AppValidator @@ -39,7 +38,7 @@ from nvflare.private.event import fire_event from nvflare.private.fed.utils.decomposers import private_decomposers from nvflare.private.privacy_manager import PrivacyManager, PrivacyService -from nvflare.security.logging import secure_format_exception, secure_log_traceback +from nvflare.security.logging import secure_format_exception from nvflare.security.security import EmptyAuthorizer, FLAuthorizer from .app_authz import AppAuthzService @@ -54,28 +53,6 @@ def add_logfile_handler(log_file): root_logger.addHandler(file_handler) -def listen_command(listen_port, engine, execute_func, logger): - conn = None - listener = None - try: - address = ("localhost", listen_port) - listener = Listener(address, authkey="client process secret password".encode()) - conn = listener.accept() - - execute_func(conn, engine) - - except Exception as e: - logger.exception( - f"Could not create the listener for this process on port: {listen_port}: {secure_format_exception(e)}." - ) - secure_log_traceback(logger) - finally: - if conn: - conn.close() - if listener: - listener.close() - - def _check_secure_content(site_type: str) -> List[str]: """To check the security contents. diff --git a/tests/unit_test/fuel/f3/communicator_test.py b/tests/unit_test/fuel/f3/communicator_test.py index b08bf56b18..45cbd178d6 100644 --- a/tests/unit_test/fuel/f3/communicator_test.py +++ b/tests/unit_test/fuel/f3/communicator_test.py @@ -90,6 +90,7 @@ class TestCommunicator: [ ("tcp", "2000-3000"), ("grpc", "3000-4000"), + ("nagrpc", "4000-5000"), # ("http", "4000-5000"), TODO (YT): We disable this, as it is causing our jenkins hanging ("atcp", "5000-6000"), ], diff --git a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py index 5137f52db5..90583c653c 100644 --- a/tests/unit_test/fuel/f3/drivers/custom_driver_test.py +++ b/tests/unit_test/fuel/f3/drivers/custom_driver_test.py @@ -18,12 +18,14 @@ from nvflare.fuel.f3 import communicator # Setup custom driver path before communicator module initialization +from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.utils.config_service import ConfigService class TestCustomDriver: @pytest.fixture def manager(self): + CommConfigurator.reset() rel_path = "../../../data/custom_drivers/config" config_path = os.path.normpath(os.path.join(os.path.dirname(__file__), rel_path)) ConfigService.initialize({}, [config_path]) diff --git a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py index a653a47af6..438ccea80b 100644 --- a/tests/unit_test/fuel/f3/drivers/driver_manager_test.py +++ b/tests/unit_test/fuel/f3/drivers/driver_manager_test.py @@ -20,6 +20,7 @@ from nvflare.fuel.f3.drivers.aio_http_driver import AioHttpDriver from nvflare.fuel.f3.drivers.aio_tcp_driver import AioTcpDriver from nvflare.fuel.f3.drivers.driver_manager import DriverManager +from nvflare.fuel.f3.drivers.grpc_driver import GrpcDriver from nvflare.fuel.f3.drivers.tcp_driver import TcpDriver @@ -37,6 +38,8 @@ def manager(self): ("stcp", TcpDriver), ("grpc", AioGrpcDriver), ("grpcs", AioGrpcDriver), + ("nagrpc", GrpcDriver), + ("nagrpcs", GrpcDriver), ("http", AioHttpDriver), ("https", AioHttpDriver), ("ws", AioHttpDriver),