Commit 98289c4

Merge branch 'NVIDIA:main' into main

nvidianz authored May 13, 2024
2 parents 07d3787 + d050cf2 commit 98289c4

Showing 36 changed files with 1,477 additions and 79 deletions.
@@ -48,6 +48,7 @@
"id": "mlflow_receiver_with_tracking_uri",
"path": "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver",
"args": {
"tracking_uri": "file:///{WORKSPACE}/{JOB_ID}/mlruns",
"kwargs": {
"experiment_name": "hello-pt-experiment",
"run_name": "hello-pt-with-mlflow",
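For context, this change points the MLflowReceiver at a per-job file store instead of leaving the URI empty. A minimal sketch of how such a file-based tracking URI behaves once NVFlare has substituted the placeholders (the resolved path below is hypothetical):

import mlflow

# Hypothetical resolved value of "file:///{WORKSPACE}/{JOB_ID}/mlruns"
# after the workspace and job-id placeholders are filled in.
tracking_uri = "file:///tmp/nvflare/jobs/workdir/server/simulate_job/mlruns"

mlflow.set_tracking_uri(tracking_uri)  # runs are written to the local mlruns directory
print(mlflow.get_tracking_uri())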
@@ -139,7 +139,7 @@ def evaluate(input_weights):
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
-global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i
+global_step = input_model.current_round * steps + epoch * len(trainloader) + i
mlflow.log_metric("loss", running_loss / 2000, global_step)
running_loss = 0.0

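The old formula scaled offsets by batch_size, so logged steps could overlap between epochs and rounds; the corrected one advances monotonically across rounds and epochs. An illustrative check (variable names assumed from the surrounding training script, where steps is the number of batches per round):

# Illustrative check of the corrected global_step formula, assuming
# 2 local epochs and 100 mini-batches per epoch.
local_epochs, batches_per_epoch = 2, 100
steps = local_epochs * batches_per_epoch  # batches per federated round

def global_step(current_round: int, epoch: int, i: int) -> int:
    # Offset by completed rounds, then by completed epochs in this round,
    # then by the current batch index.
    return current_round * steps + epoch * batches_per_epoch + i

assert global_step(0, 0, 0) == 0
assert global_step(0, 1, 0) == 100  # second local epoch continues the count
assert global_step(1, 0, 0) == 200  # next round continues where the last ended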
46 changes: 44 additions & 2 deletions examples/hello-world/step-by-step/cifar10/sag/sag.ipynb
@@ -232,8 +232,8 @@
"source": [
"! nvflare job create -j /tmp/nvflare/jobs/cifar10_sag_pt -w sag_pt_in_proc \\\n",
"-f meta.conf min_clients=2 \\\n",
"-f config_fed_client.conf app_script=train.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n",
"-f config_fed_server.conf num_rounds=5 \\\n",
"-f config_fed_client.conf app_script=train_with_mlflow.py app_config=\"--batch_size 4 --dataset_path {CIFAR10_ROOT} --num_workers 2\" \\\n",
"-f config_fed_server.conf num_rounds=2 \\\n",
"-sd ../code/fl \\\n",
"-force"
]
@@ -289,6 +289,48 @@
"The next 5 examples will use the same ScatterAndGather workflow, but will demonstrate different execution APIs and feature.\n",
"In the next example [sag_deploy_map](../sag_deploy_map/sag_deploy_map.ipynb), we will learn about the deploy_map configuration for deployment of apps to different sites."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a49b430b-a65b-4b1e-8793-9b3befcfcfd9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!tree /tmp/nvflare/jobs/cifar10_sag_pt_workspace/"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50594df7-b4c9-4e5e-944a-403b5a105c27",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!mlflow ui --port 5000 --backend-store-uri /tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/mlruns"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af2b6628-61af-4bc8-84d4-a9876a27c7c2",
"metadata": {},
"outputs": [],
"source": [
"!tensorboard --logdir=/tmp/nvflare/jobs/cifar10_sag_pt_workspace/server/simulate_job/tb_events"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3ad11c3-6ef7-46cd-8778-0090505b14e1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
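The three new cells inspect the workspace produced by a simulator run: tree shows the job layout, mlflow ui serves the file-based MLflow store, and tensorboard reads the streamed TB events. They assume the job was first executed into the matching workspace, e.g. with the FL simulator (an invocation sketched here; flags may vary by NVFlare version):

! nvflare simulator /tmp/nvflare/jobs/cifar10_sag_pt -w /tmp/nvflare/jobs/cifar10_sag_pt_workspace -n 2 -t 2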
2 changes: 1 addition & 1 deletion job_templates/sag_pt_in_proc/config_fed_server.conf
@@ -107,7 +107,7 @@
path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver"
args {
# tracking_uri = "http://0.0.0.0:5000"
tracking_uri = ""
tracking_uri = "file:///{WORKSPACE}/{JOB_ID}/mlruns"
kwargs {
experiment_name = "nvflare-sag-pt-experiment"
run_name = "nvflare-sag-pt-with-mlflow"
15 changes: 11 additions & 4 deletions nvflare/app_common/executors/in_process_client_api_executor.py
@@ -61,6 +61,7 @@ def __init__(
submit_model_task_name: str = "submit_model",
):
super(InProcessClientAPIExecutor, self).__init__()
self._abort = False
self._client_api = None
self._result_pull_interval = result_pull_interval
self._log_pull_interval = log_pull_interval
@@ -93,6 +94,7 @@ def __init__(
self._event_manager = EventManager(self._data_bus)
self._data_bus.subscribe([TOPIC_LOCAL_RESULT], self.local_result_callback)
self._data_bus.subscribe([TOPIC_LOG_DATA], self.log_result_callback)
self._data_bus.subscribe([TOPIC_ABORT, TOPIC_STOP], self.to_abort_callback)
self.local_result = None
self._fl_ctx = None
self._task_fn_path = None
@@ -106,17 +108,19 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
self._init_converter(fl_ctx)

self._task_fn_wrapper = TaskScriptRunner(
-    script_path=self._task_script_path, script_args=self._task_script_args
+    site_name=fl_ctx.get_identity_name(),
+    script_path=self._task_script_path,
+    script_args=self._task_script_args,
)

self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
-self._task_fn_thread.start()

meta = self._prepare_task_meta(fl_ctx, None)
self._client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=self._result_pull_interval)
self._client_api.init()
self._data_bus.put_data(CLIENT_API_KEY, self._client_api)

+self._task_fn_thread.start()

elif event_type == EventType.END_RUN:
self._event_manager.fire_event(TOPIC_STOP, "END_RUN received")
if self._task_fn_thread:
@@ -142,7 +146,7 @@ def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort
# wait for result
self.log_info(fl_ctx, "Waiting for result from peer")
while True:
-if abort_signal.triggered:
+if abort_signal.triggered or self._abort is True:
# notify peer that the task is aborted
self._event_manager.fire_event(TOPIC_ABORT, f"'{task_name}' is aborted, abort_signal_triggered")
return make_reply(ReturnCode.TASK_ABORTED)
@@ -231,3 +235,6 @@ def log_result_callback(self, topic, data, databus):
# fire_fed_event = True w/o fed_event_converter somehow did not work
with self._engine.new_context() as fl_ctx:
send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=ANALYTIC_EVENT_TYPE, fire_fed_event=False)

def to_abort_callback(self, topic, data, databus):
self._abort = True
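The new subscription means an abort or stop published on the in-process data bus flips self._abort, which the execute() wait loop polls alongside abort_signal. A minimal, self-contained sketch of that pub/sub wiring (a simplified stand-in, not NVFlare's actual DataBus API):

from typing import Callable, Dict, List

class MiniDataBus:
    def __init__(self):
        self._subs: Dict[str, List[Callable]] = {}

    def subscribe(self, topics: List[str], cb: Callable) -> None:
        for t in topics:
            self._subs.setdefault(t, []).append(cb)

    def publish(self, topic: str, data) -> None:
        for cb in self._subs.get(topic, []):
            cb(topic, data, self)

class MiniExecutor:
    def __init__(self, bus: MiniDataBus):
        self._abort = False
        bus.subscribe(["ABORT", "STOP"], self.to_abort_callback)

    def to_abort_callback(self, topic, data, databus):
        self._abort = True  # polled by the task wait loop

bus = MiniDataBus()
executor = MiniExecutor(bus)
bus.publish("ABORT", "training script failed")
assert executor._abort is True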
60 changes: 43 additions & 17 deletions nvflare/app_common/executors/task_script_runner.py
@@ -14,75 +14,101 @@
import builtins
import logging
import os
import runpy
import sys
import traceback

from nvflare.client.in_process.api import TOPIC_ABORT
from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.fuel.data_event.event_manager import EventManager

print_fn = builtins.print


class TaskScriptRunner:
logger = logging.getLogger(__name__)

-def __init__(self, script_path: str, script_args: str = None):
+def __init__(self, site_name: str, script_path: str, script_args: str = None, redirect_print_to_log=True):
"""Wrapper for function given function path and args
Args:
site_name (str): site name
script_path (str): script file name, such as train.py
script_args (str, Optional): script arguments to pass in.
"""

self.redirect_print_to_log = redirect_print_to_log
self.event_manager = EventManager(DataBus())
self.script_args = script_args
self.client_api = None
self.site_name = site_name
self.logger = logging.getLogger(self.__class__.__name__)
-self.script_path = self.get_script_full_path(script_path)
+self.script_path = script_path
+self.script_full_path = self.get_script_full_path(self.site_name, self.script_path)

def run(self):
"""Call the task_fn with any required arguments."""
self.logger.info(f"\n start task run() with {self.script_path}")
self.logger.info(f"\n start task run() with full path: {self.script_full_path}")
try:
-import runpy

curr_argv = sys.argv
-builtins.print = log_print
+builtins.print = log_print if self.redirect_print_to_log else print_fn
sys.argv = self.get_sys_argv()
-runpy.run_path(self.script_path, run_name="__main__")
+runpy.run_path(self.script_full_path, run_name="__main__")
sys.argv = curr_argv

except ImportError as ie:
msg = "attempted relative import with no known parent package"
if ie.msg == msg:
xs = [p for p in sys.path if self.script_full_path.startswith(p)]
import_base_path = max(xs, key=len)
raise ImportError(
f"{ie.msg}, the relative import is not support. python import is based off the sys.path: {import_base_path}"
)
else:
raise ie
except Exception as e:
msg = traceback.format_exc()
self.logger.error(msg)
-if self.client_api:
-    self.client_api.exec_queue.ask_abort(msg)
+self.logger.error("fire abort event")
+self.event_manager.fire_event(TOPIC_ABORT, f"'{self.script_full_path}' is aborted, {msg}")
raise e
finally:
builtins.print = print_fn

def get_sys_argv(self):
args_list = [] if not self.script_args else self.script_args.split()
-return [self.script_path] + args_list
+return [self.script_full_path] + args_list
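run() executes the user script in-process via runpy with a patched sys.argv, so the script behaves as if launched from the command line. A minimal sketch of that pattern (the script path here is hypothetical):

import runpy
import sys

script = "/tmp/example_train.py"  # hypothetical script for illustration
with open(script, "w") as f:
    f.write("import sys; print('argv:', sys.argv)\n")

saved_argv = sys.argv
sys.argv = [script, "--epochs", "2"]  # what get_sys_argv() assembles
try:
    runpy.run_path(script, run_name="__main__")  # runs the file as "__main__"
finally:
    sys.argv = saved_argv  # always restore, even if the script raises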

-def get_script_full_path(self, script_path) -> str:
+def get_script_full_path(self, site_name, script_path) -> str:
target_file = None
script_filename = os.path.basename(script_path)
script_dirs = os.path.dirname(script_path)

if os.path.isabs(script_path):
if not os.path.isfile(script_path):
raise ValueError(f"script_path='{script_path}' not found")
return script_path

for r, dirs, files in os.walk(os.getcwd()):
for f in files:
absolute_path = os.path.join(r, f)
if absolute_path.endswith(script_path):
parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep)
if os.path.isdir(parent_dir):
-target_file = absolute_path
-break
+path_components = parent_dir.split(os.path.sep)
+if site_name in path_components:
+    target_file = absolute_path
+    break

-if not script_dirs and f == script_filename:
+if not site_name and not script_dirs and f == script_filename:
target_file = absolute_path
break

if target_file:
break

if not target_file:
raise ValueError(f"{script_path} is not found")
msg = f"Can not find {script_path}"
self.event_manager.fire_event(TOPIC_ABORT, f"'{self.script_path}' is aborted, {msg}")
raise ValueError(msg)
return target_file


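The lookup is now site-aware: when several clients run in one process (as in the simulator), each site can carry its own copy of train.py, and the walk only accepts a match whose parent directory path contains the site name. A standalone sketch of that idea (directory layout and names are illustrative):

import os

def find_site_script(root: str, site_name: str, script_path: str):
    """Return the first script under root whose parent path contains site_name."""
    for r, _dirs, files in os.walk(root):
        for f in files:
            absolute_path = os.path.join(r, f)
            if absolute_path.endswith(script_path):
                parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep)
                if site_name in parent_dir.split(os.path.sep):
                    return absolute_path  # match under the site's own directory
    return None

# e.g. find_site_script("/tmp/simulate_job", "site-1", "train.py") would pick
# /tmp/simulate_job/site-1/custom/train.py over another site's copy.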
9 changes: 7 additions & 2 deletions nvflare/client/ipc/ipc_agent.py
@@ -104,8 +104,13 @@ def __init__(
self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_HEARTBEAT, cb=self._handle_heartbeat)
self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_BYE, cb=self._handle_bye)
self.cell.register_request_cb(channel=defs.CHANNEL, topic=defs.TOPIC_ABORT, cb=self._handle_abort_task)
-self.cell.add_incoming_request_filter(
-    channel=defs.CHANNEL,
+self.cell.core_cell.add_incoming_request_filter(
+    channel="*",
    topic="*",
    cb=self._msg_received,
)
+self.cell.core_cell.add_incoming_reply_filter(
+    channel="*",
+    topic="*",
+    cb=self._msg_received,
+)
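The filters are widened from one channel to all channels, and a reply filter is added, so _msg_received now observes every incoming request and reply. A generic sketch of the wildcard-filter idea (a simplified stand-in, not CoreCell's actual registry):

from collections import defaultdict

class FilterRegistry:
    def __init__(self):
        self.filters = defaultdict(list)  # (channel, topic) -> callbacks

    def add_filter(self, channel, topic, cb):
        self.filters[(channel, topic)].append(cb)

    def dispatch(self, channel, topic, message):
        # Exact match plus wildcard patterns all fire.
        for pattern in ((channel, topic), ("*", topic), (channel, "*"), ("*", "*")):
            for cb in self.filters[pattern]:
                cb(message)

registry = FilterRegistry()
registry.add_filter("*", "*", lambda m: print("seen:", m))
registry.dispatch("admin", "heart_beat", {"ok": True})  # wildcard filter fires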
18 changes: 17 additions & 1 deletion nvflare/dashboard/cli.py
@@ -153,7 +153,11 @@ def cloud(args):
"t",
exe=True,
)
print(f"Dashboard launch script for cloud is written at {dest}. Now running the script.")
print(f"Dashboard launch script for cloud is written at {dest}. Now running it.")
if args.vpc_id and args.subnet_id:
option = [f"--vpc-id={args.vpc_id}", f"--subnet-id={args.subnet_id}"]
print(f"Option of the script: {option}")
dest = [dest] + option
_ = subprocess.run(dest)
os.remove(dest)

@@ -192,6 +196,18 @@ def define_dashboard_parser(parser):
parser.add_argument("--cred", help="set credential directly in the form of USER_EMAIL:PASSWORD")
parser.add_argument("-i", "--image", help="set the container image name")
parser.add_argument("--local", action="store_true", help="start dashboard locally without docker image")
parser.add_argument(
"--vpc-id",
type=str,
default="",
help="VPC id for AWS EC2 instance. Applicable to AWS only. Ignored if subnet-id is not specified.",
)
parser.add_argument(
"--subnet-id",
type=str,
default="",
help="Subnet id for AWS EC2 instance. Applicable to AWS only. Ignored if vpc-id is not specified.",
)


def handle_dashboard(args):
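The two flags only take effect together: cloud(args) appends them to the generated launch script's command line when both are present. A condensed sketch of that behavior (the function name and script path here are illustrative):

import subprocess

def run_launch_script(dest: str, vpc_id: str, subnet_id: str) -> None:
    cmd = [dest]
    if vpc_id and subnet_id:  # one without the other is ignored
        cmd += [f"--vpc-id={vpc_id}", f"--subnet-id={subnet_id}"]
    subprocess.run(cmd)

# e.g. run_launch_script("./aws_start.sh", "vpc-0abc", "subnet-0def")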
3 changes: 1 addition & 2 deletions nvflare/fuel/f3/cellnet/cell.py
@@ -20,13 +20,12 @@
from typing import Dict, List, Union

from nvflare.fuel.f3.cellnet.core_cell import CoreCell, TargetMessage
-from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, MessageType, ReturnCode
+from nvflare.fuel.f3.cellnet.defs import CellChannel, MessageHeaderKey, MessageType, ReturnCode
from nvflare.fuel.f3.cellnet.utils import decode_payload, encode_payload, make_reply
from nvflare.fuel.f3.message import Message
from nvflare.fuel.f3.stream_cell import StreamCell
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
from nvflare.fuel.f3.streaming.stream_types import StreamFuture
-from nvflare.private.defs import CellChannel
from nvflare.security.logging import secure_format_exception

CHANNELS_TO_EXCLUDE = (
36 changes: 36 additions & 0 deletions nvflare/fuel/f3/cellnet/defs.py
@@ -135,3 +135,39 @@ class AbortRun(Exception):

class InvalidRequest(Exception):
pass


class SSLConstants:
"""hard coded names related to SSL."""

CERT = "ssl_cert"
PRIVATE_KEY = "ssl_private_key"
ROOT_CERT = "ssl_root_cert"


class CellChannel:

CLIENT_MAIN = "admin"
AUX_COMMUNICATION = "aux_communication"
SERVER_MAIN = "task"
SERVER_COMMAND = "server_command"
SERVER_PARENT_LISTENER = "server_parent_listener"
CLIENT_COMMAND = "client_command"
CLIENT_SUB_WORKER_COMMAND = "client_sub_worker_command"
MULTI_PROCESS_EXECUTOR = "multi_process_executor"
SIMULATOR_RUNNER = "simulator_runner"
RETURN_ONLY = "return_only"


class CellChannelTopic:

Register = "register"
Quit = "quit"
GET_TASK = "get_task"
SUBMIT_RESULT = "submit_result"
HEART_BEAT = "heart_beat"
EXECUTE_RESULT = "execute_result"
FIRE_EVENT = "fire_event"
REPORT_JOB_FAILURE = "report_job_failure"

SIMULATOR_WORKER_INIT = "simulator_worker_init"
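Moving CellChannel and CellChannelTopic into the fuel layer lets cell.py import them without reaching into nvflare.private (see the import change above), removing a dependency from fuel code onto private code. Usage is unchanged; a quick check against the values defined here:

from nvflare.fuel.f3.cellnet.defs import CellChannel, CellChannelTopic

assert CellChannel.SERVER_MAIN == "task"
assert CellChannel.CLIENT_MAIN == "admin"
assert CellChannelTopic.GET_TASK == "get_task"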