From 37cf9d0ff5b798551fbbb0d9a57629bed178671c Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 26 Sep 2024 14:14:48 -0700 Subject: [PATCH] Update flwr job object, client, server --- examples/hello-world/hello-flower/README.md | 2 +- .../flwr-pt-tb/flwr_pt_tb/client.py | 10 +-- .../flwr-pt-tb/flwr_pt_tb/server.py | 17 ++-- examples/hello-world/hello-flower/job.py | 8 +- nvflare/app_opt/flower/flower_job.py | 6 +- nvflare/app_opt/flower/flower_pt_job.py | 88 +++++++++++++++++++ 6 files changed, 110 insertions(+), 21 deletions(-) create mode 100644 nvflare/app_opt/flower/flower_pt_job.py diff --git a/examples/hello-world/hello-flower/README.md b/examples/hello-world/hello-flower/README.md index b896913c5e..6ef847ad54 100644 --- a/examples/hello-world/hello-flower/README.md +++ b/examples/hello-world/hello-flower/README.md @@ -47,5 +47,5 @@ Next, we run 2 Flower clients and Flower Server in parallel using NVFlare while the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming. ```bash -python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics --use_client_api +python job.py --job_name "flwr-pt-tb" --content_dir "./flwr-pt-tb" --stream_metrics ``` diff --git a/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py index c79d475ff6..35f6668483 100644 --- a/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py +++ b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/client.py @@ -36,18 +36,16 @@ class FlowerClient(NumPyClient): def __init__(self, context: Context): super().__init__() self.writer = SummaryWriter() - self.set_context(context) + self.flwr_context = context + if "step" not in context.state.metrics_records: self.set_step(0) def set_step(self, step: int): - context = self.get_context() - context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})}) - self.set_context(context) + self.flwr_context.state = RecordSet(metrics_records={"step": MetricsRecord({"step": step})}) def get_step(self): - context = self.get_context() - return int(context.state.metrics_records["step"]["step"]) + return int(self.flwr_context.state.metrics_records["step"]["step"]) def fit(self, parameters, config): step = self.get_step() diff --git a/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py index 3093d3b5e8..0ee418c3b4 100644 --- a/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py +++ b/examples/hello-world/hello-flower/flwr-pt-tb/flwr_pt_tb/server.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import List, Tuple -from flwr.common import Metrics, ndarrays_to_parameters -from flwr.server import ServerApp, ServerConfig +from flwr.common import Context, Metrics, ndarrays_to_parameters +from flwr.server import ServerApp, ServerAppComponents, ServerConfig from flwr.server.strategy import FedAvg from .task import Net, get_weights @@ -53,13 +53,16 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: initial_parameters=parameters, ) - # Define config config = ServerConfig(num_rounds=3) # Flower ServerApp -app = ServerApp( - config=config, - strategy=strategy, -) +def server_fn(context: Context): + return ServerAppComponents( + strategy=strategy, + config=config, + ) + + +app = ServerApp(server_fn=server_fn) diff --git a/examples/hello-world/hello-flower/job.py b/examples/hello-world/hello-flower/job.py index 558c29ed4e..7ebed39778 100644 --- a/examples/hello-world/hello-flower/job.py +++ b/examples/hello-world/hello-flower/job.py @@ -14,7 +14,7 @@ from argparse import ArgumentParser -from nvflare.app_opt.flower.flower_job import FlowerJob +from nvflare.app_opt.flower.flower_pt_job import FlowerPyTorchJob from nvflare.client.api import ClientAPIType from nvflare.client.api_spec import CLIENT_API_TYPE_KEY @@ -30,10 +30,12 @@ def main(): args = parser.parse_args() env = {} - if args.use_client_api: + if args.stream_metrics or args.use_client_api: + # needs to init client api to stream metrics + # only external client api works with the current flower integration env = {CLIENT_API_TYPE_KEY: ClientAPIType.EX_PROCESS_API.value} - job = FlowerJob( + job = FlowerPyTorchJob( name=args.job_name, flower_content=args.content_dir, stream_metrics=args.stream_metrics, diff --git a/nvflare/app_opt/flower/flower_job.py b/nvflare/app_opt/flower/flower_job.py index 97bcdbe899..adce76bc26 100644 --- a/nvflare/app_opt/flower/flower_job.py +++ b/nvflare/app_opt/flower/flower_job.py @@ -19,7 +19,6 @@ from nvflare.app_common.widgets.external_configurator import ExternalConfigurator from nvflare.app_common.widgets.metric_relay import MetricRelay from nvflare.app_common.widgets.streaming import AnalyticsReceiver -from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver from nvflare.fuel.utils.pipe.cell_pipe import CellPipe from nvflare.fuel.utils.validation_utils import check_object_type from nvflare.job_config.api import FedJob @@ -104,10 +103,9 @@ def __init__( # server side - need analytics_receiver if analytics_receiver: check_object_type("analytics_receiver", analytics_receiver, AnalyticsReceiver) + self.to_server(analytics_receiver, "analytics_receiver") else: - analytics_receiver = TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) - - self.to_server(analytics_receiver, "analytics_receiver") + raise ValueError("Missing analytics receiver on the server side.") # client side # cell pipe diff --git a/nvflare/app_opt/flower/flower_pt_job.py b/nvflare/app_opt/flower/flower_pt_job.py new file mode 100644 index 0000000000..bd350e6c37 --- /dev/null +++ b/nvflare/app_opt/flower/flower_pt_job.py @@ -0,0 +1,88 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from nvflare.app_common.tie.defs import Constant +from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver + +from .flower_job import FlowerJob + + +class FlowerPyTorchJob(FlowerJob): + def __init__( + self, + name: str, + flower_content: str, + min_clients: int = 1, + mandatory_clients: Optional[List[str]] = None, + database: str = "", + server_app_args: list = None, + superlink_ready_timeout: float = 10.0, + configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT, + start_task_timeout=Constant.START_TASK_TIMEOUT, + max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL, + progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, + per_msg_timeout=10.0, + tx_timeout=100.0, + client_shutdown_timeout=5.0, + stream_metrics=False, + analytics_receiver=None, + extra_env: dict = None, + ): + """ + Flower Job. + + Args: + name (str): Name of the job. + flower_content (str): Content for the flower job. + min_clients (int, optional): The minimum number of clients for the job. Defaults to 1. + mandatory_clients (List[str], optional): List of mandatory clients for the job. Defaults to None. + database (str, optional): Database string. Defaults to "". + server_app_args (list, optional): List of arguments to pass to the server application. Defaults to None. + superlink_ready_timeout (float, optional): Timeout for the superlink to be ready. Defaults to 10.0 seconds. + configure_task_timeout (float, optional): Timeout for configuring the task. Defaults to Constant.CONFIG_TASK_TIMEOUT. + start_task_timeout (float, optional): Timeout for starting the task. Defaults to Constant.START_TASK_TIMEOUT. + max_client_op_interval (float, optional): Maximum interval between client operations. Defaults to Constant.MAX_CLIENT_OP_INTERVAL. + progress_timeout (float, optional): Timeout for workflow progress. Defaults to Constant.WORKFLOW_PROGRESS_TIMEOUT. + per_msg_timeout (float, optional): Timeout for receiving individual messages. Defaults to 10.0 seconds. + tx_timeout (float, optional): Timeout for transmitting data. Defaults to 100.0 seconds. + client_shutdown_timeout (float, optional): Timeout for client shutdown. Defaults to 5.0 seconds. + stream_metrics (bool, optional): Whether to stream metrics from Flower client to Flare + analytics_receiver (AnalyticsReceiver, optional): the AnalyticsReceiver to use to process received metrics. + extra_env (dict, optional): optional extra env variables to be passed to Flower client + """ + analytics_receiver = ( + analytics_receiver if analytics_receiver else TBAnalyticsReceiver(events=["fed.analytix_log_stats"]) + ) + + super().__init__( + name=name, + flower_content=flower_content, + min_clients=min_clients, + mandatory_clients=mandatory_clients, + database=database, + server_app_args=server_app_args, + superlink_ready_timeout=superlink_ready_timeout, + configure_task_timeout=configure_task_timeout, + start_task_timeout=start_task_timeout, + max_client_op_interval=max_client_op_interval, + progress_timeout=progress_timeout, + per_msg_timeout=per_msg_timeout, + tx_timeout=tx_timeout, + client_shutdown_timeout=client_shutdown_timeout, + stream_metrics=stream_metrics, + analytics_receiver=analytics_receiver, + extra_env=extra_env, + )