From 14d6b696e840192a6315f51f226c3b4a34dc2731 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 21 Feb 2024 12:35:39 -0500 Subject: [PATCH 01/44] WIP: tdx_cc integration. --- .../confidential_computing/cc_manager.py | 9 +- .../confidential_computing/tdx_connector.py | 109 ++++++++++++++++++ .../private/fed/app/client/client_train.py | 1 + .../fed/app/deployer/server_deployer.py | 3 +- 4 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 nvflare/app_opt/confidential_computing/tdx_connector.py diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 82fa7dba8a..7cc305c0d0 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -17,8 +17,9 @@ from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError +from nvflare.app_opt.confidential_computing.tdx_connector import TDXCCHelper -from .cc_helper import CCHelper +# from .cc_helper import CCHelper PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" CC_TOKEN = "_cc_token" @@ -157,7 +158,11 @@ def _add_client_token(self, fl_ctx: FLContext): def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() - self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers) + workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() + # self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers) + self.helper = TDXCCHelper(site_name=self.site_name, + tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", + config_dir=workspace_folder) ok = self.helper.prepare() if not ok: return "failed to attest" diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_connector.py new file mode 100644 index 0000000000..abe9fa9ca4 --- /dev/null +++ b/nvflare/app_opt/confidential_computing/tdx_connector.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import shlex +import subprocess +from typing import Dict + +# TDX_CLI_COMMAND = "./trustauthority-cli" +TDX_CLI_CONFIG = "config.json" +TOKEN_FILE = "token.txt" +VERIFY_FILE = "verify.txt" +ERROR_FILE = "error.txt" + + +class TDXConnector: + def __init__(self, tdx_cli_command: str, config_dir: str) -> None: + super().__init__() + self.tdx_cli_command = tdx_cli_command + self.config_dir = config_dir + + self.config_file = os.path.join(self.config_dir, TDX_CLI_CONFIG) + + def get_token(self): + out = open(os.path.join(self.config_dir, TOKEN_FILE), "w") + err_out = open(os.path.join(self.config_dir, ERROR_FILE), "w") + + command = "sudo " + self.tdx_cli_command + " -c " + self.config_file + " token --no-eventlog " + process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) + process.wait() + + with open(TOKEN_FILE, "r") as f: + token = f.readline() + with open(ERROR_FILE, "r") as f: + if 'Error:' in f.read(): + error = True + else: + error = False + + return token, error + + def verify_token(self, token: str): + out = open(os.path.join(self.config_dir, VERIFY_FILE), "w") + err_out = open(os.path.join(self.config_dir, ERROR_FILE), "w") + + command = self.tdx_cli_command + " verify --config " + self.config_file + " --token " + token + process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) + process.wait() + + # with open(VERIFY_FILE, "r") as f: + # result = f.readline() + with open(ERROR_FILE, "r") as f: + if 'Error:' in f.read(): + return False + + return True + + +class TDXCCHelper: + + def __init__(self, site_name: str, tdx_cli_command: str, config_dir: str) -> None: + super().__init__() + self.site_name = site_name + # self.tdx_cli_command = tdx_cli_command + # self.config_dir = config_dir + self.token = None + + self.tdx_connector = TDXConnector(tdx_cli_command, config_dir) + self.logger = logging.getLogger(self.__class__.__name__) + + def prepare(self) -> bool: + self.token = self.tdx_connector.get_token() + self.logger.info(f"site: {self.site_name} got the token: {self.token}") + return True + + def get_token(self): + return self.token + + def validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: + result = {} + if not participants: + return result + for k, v in participants.items(): + if self.tdx_connector.verify_token(v): + result[k] = True + self.logger.info(f"CC - results from validating participants' tokens: {result}") + return result + + +if __name__ == "__main__": + tdx_connector = TDXConnector() + token, error = tdx_connector.get_token() + print("--- Acquire the token ---") + print(token) + + result = tdx_connector.verify_token(token) + print("---- Verify the token ---") + print(result) diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index f933c14349..aa68ad5adb 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -108,6 +108,7 @@ def main(args): time.sleep(1.0) with client_engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index e800820497..899832dcba 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -16,7 +16,7 @@ import threading from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import SystemComponents +from nvflare.apis.fl_constant import SystemComponents, FLContextKey from nvflare.apis.workspace import Workspace from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.job_runner import JobRunner @@ -119,6 +119,7 @@ def deploy(self, args): run_manager.add_component(SystemComponents.JOB_RUNNER, job_runner) with services.engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() From bc6b21e1c9e2ef6e7a816540e8c75b15bf349ff2 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 21 Feb 2024 12:55:18 -0500 Subject: [PATCH 02/44] fixed toke_file read. --- .../confidential_computing/tdx_connector.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_connector.py index abe9fa9ca4..7203b38996 100644 --- a/nvflare/app_opt/confidential_computing/tdx_connector.py +++ b/nvflare/app_opt/confidential_computing/tdx_connector.py @@ -33,16 +33,18 @@ def __init__(self, tdx_cli_command: str, config_dir: str) -> None: self.config_file = os.path.join(self.config_dir, TDX_CLI_CONFIG) def get_token(self): - out = open(os.path.join(self.config_dir, TOKEN_FILE), "w") - err_out = open(os.path.join(self.config_dir, ERROR_FILE), "w") + token_file = os.path.join(self.config_dir, TOKEN_FILE) + out = open(token_file, "w") + error_file = os.path.join(self.config_dir, ERROR_FILE) + err_out = open(error_file, "w") command = "sudo " + self.tdx_cli_command + " -c " + self.config_file + " token --no-eventlog " process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) process.wait() - with open(TOKEN_FILE, "r") as f: + with open(token_file, "r") as f: token = f.readline() - with open(ERROR_FILE, "r") as f: + with open(error_file, "r") as f: if 'Error:' in f.read(): error = True else: @@ -52,7 +54,8 @@ def get_token(self): def verify_token(self, token: str): out = open(os.path.join(self.config_dir, VERIFY_FILE), "w") - err_out = open(os.path.join(self.config_dir, ERROR_FILE), "w") + error_file = os.path.join(self.config_dir, ERROR_FILE) + err_out = open(error_file, "w") command = self.tdx_cli_command + " verify --config " + self.config_file + " --token " + token process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) @@ -60,7 +63,7 @@ def verify_token(self, token: str): # with open(VERIFY_FILE, "r") as f: # result = f.readline() - with open(ERROR_FILE, "r") as f: + with open(error_file, "r") as f: if 'Error:' in f.read(): return False From 9f6b71592d9243a55949d77d7b9f6155d7563553 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 21 Feb 2024 14:08:02 -0500 Subject: [PATCH 03/44] WIP: added info for CC add client tokens.: --- nvflare/app_opt/confidential_computing/cc_manager.py | 4 +++- nvflare/app_opt/confidential_computing/tdx_connector.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 7cc305c0d0..0245b026ea 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -155,6 +155,8 @@ def _add_client_token(self, fl_ctx: FLContext): self.participant_cc_info[token_owner] = peer_cc_info self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False + self.logger.info(f"Added CC client: {token_owner}") + def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() @@ -213,7 +215,7 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: for p in participant_tokens: self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False) - self.logger.debug(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") + self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") return "" def _validate_participants_tokens(self, participants) -> str: diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_connector.py index 7203b38996..24a22a5302 100644 --- a/nvflare/app_opt/confidential_computing/tdx_connector.py +++ b/nvflare/app_opt/confidential_computing/tdx_connector.py @@ -83,9 +83,9 @@ def __init__(self, site_name: str, tdx_cli_command: str, config_dir: str) -> Non self.logger = logging.getLogger(self.__class__.__name__) def prepare(self) -> bool: - self.token = self.tdx_connector.get_token() + self.token, error = self.tdx_connector.get_token() self.logger.info(f"site: {self.site_name} got the token: {self.token}") - return True + return not error def get_token(self): return self.token From dd281f57b82e6ceb1dde25602e99000eabe96014 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 21 Feb 2024 17:02:33 -0500 Subject: [PATCH 04/44] Fixed an error when client does not have CC token reported. --- nvflare/app_opt/confidential_computing/cc_manager.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 0245b026ea..043ba7cfc7 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -206,7 +206,10 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: continue if p not in self.participant_cc_info: return f"no token available for participant {p}" - participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] + if self.participant_cc_info.get(p): + participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] + else: + participant_tokens[p] = "" err = self._validate_participants_tokens(participant_tokens) if err: From 4d87722440b30ef9e95d8add5c67a5b867ab4099 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 21 Feb 2024 17:06:52 -0500 Subject: [PATCH 05/44] Added handle for client does not have CC_INFO. --- nvflare/app_opt/confidential_computing/cc_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 043ba7cfc7..aa45f7b98a 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -151,7 +151,7 @@ def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - peer_cc_info = peer_ctx.get_prop(CC_INFO) + peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: "", CC_TOKEN_VALIDATED: False}) self.participant_cc_info[token_owner] = peer_cc_info self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False From 8830069ee2c69698fffde924bd04e6f607afa12e Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 22 Feb 2024 09:14:00 -0500 Subject: [PATCH 06/44] Added CLIENT_QUIT event for CCManager to remove client token. --- nvflare/apis/event_type.py | 1 + nvflare/app_opt/confidential_computing/cc_manager.py | 9 +++++++++ nvflare/private/fed/server/fed_server.py | 2 ++ 3 files changed, 12 insertions(+) diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 942da170c5..bddfdb9818 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -72,6 +72,7 @@ class EventType(object): BEFORE_CLIENT_REGISTER = "_before_client_register" AFTER_CLIENT_REGISTER = "_after_client_register" CLIENT_REGISTERED = "_client_registered" + CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index aa45f7b98a..219d811894 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -114,6 +114,9 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): elif event_type == EventType.CLIENT_REGISTERED: # Server side self._add_client_token(fl_ctx) + elif event_type == EventType.CLIENT_QUIT: + # Server side + self._remove_client_token(fl_ctx) elif event_type == EventType.AUTHORIZE_COMMAND_CHECK: command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) self.logger.debug(f"Received {command_to_check=}") @@ -157,6 +160,12 @@ def _add_client_token(self, fl_ctx: FLContext): self.logger.info(f"Added CC client: {token_owner}") + def _remove_client_token(self, fl_ctx: FLContext): + # server side + peer_ctx = fl_ctx.get_peer_context() + token_owner = peer_ctx.get_identity_name() + self.participant_cc_info.pop(token_owner) + def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index d07dc99246..df4100e0f6 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -539,6 +539,8 @@ def quit_client(self, request: Message) -> Message: token = client.get_token() self.logout_client(token) + self.engine.fire_event(EventType.CLIENT_QUIT, fl_ctx=fl_ctx) + headers = {CellMessageHeaderKeys.MESSAGE: "Removed client"} return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) From 9fbdde3eca06d4011661004d41c8c959c30464a7 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 22 Feb 2024 09:31:48 -0500 Subject: [PATCH 07/44] Added _add_client_token client token logging info. --- nvflare/app_opt/confidential_computing/cc_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 219d811894..27cf70b898 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -154,17 +154,18 @@ def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: "", CC_TOKEN_VALIDATED: False}) + peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: ""}) self.participant_cc_info[token_owner] = peer_cc_info self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False - self.logger.info(f"Added CC client: {token_owner}") + self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") def _remove_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() self.participant_cc_info.pop(token_owner) + self.logger.info(f"Removed CC client: {token_owner}") def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides From dd9fe74c7433b5354b340f4522e66caa8a89c989 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 22 Feb 2024 09:56:24 -0500 Subject: [PATCH 08/44] Added peer_ctx for client quit. --- nvflare/private/fed/client/communicator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index d3475f0253..7dc6525516 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -292,6 +292,9 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): server's reply to the last message """ + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable = Shareable() + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { @@ -299,7 +302,8 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, - } + }, + shareable ) try: result = self.cell.send_request( From 2f44ceff38c7c89fed43a312ff4d92c08f003544 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 22 Feb 2024 10:05:17 -0500 Subject: [PATCH 09/44] set_peer_context for client quit. --- nvflare/private/fed/server/fed_server.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index df4100e0f6..1bc9b22bca 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -539,6 +539,9 @@ def quit_client(self, request: Message) -> Message: token = client.get_token() self.logout_client(token) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) self.engine.fire_event(EventType.CLIENT_QUIT, fl_ctx=fl_ctx) headers = {CellMessageHeaderKeys.MESSAGE: "Removed client"} From 0f87eb3e6649c844fc8be537be2f7f1b4b17600c Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 26 Feb 2024 11:09:04 -0500 Subject: [PATCH 10/44] Changed the AUTHORIZATION_REASON set_prop sticky to False. --- nvflare/app_opt/confidential_computing/cc_manager.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 27cf70b898..60a2ef0b07 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -248,11 +248,11 @@ def _validate_participants_tokens(self, participants) -> str: def _not_authorize_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_REASON, value=reason) - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False) + fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_REASON, value=reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason) - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) From 6ec9ae66c6fa2a780448871fa45ae184e93dc45a Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 26 Feb 2024 14:22:58 -0500 Subject: [PATCH 11/44] WIP: TokenPundit interface change. --- .../confidential_computing/cc_authorizer.py | 58 ++++++++ .../confidential_computing/cc_manager.py | 40 ++++-- .../confidential_computing/tdx_connector.py | 129 ++++++++++-------- 3 files changed, 158 insertions(+), 69 deletions(-) create mode 100644 nvflare/app_opt/confidential_computing/cc_authorizer.py diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py new file mode 100644 index 0000000000..362000b6a9 --- /dev/null +++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import os.path + +class TokenPundit: + def can_generate(self) -> bool: + """This indicates if the authorizer can generate a CC token or not. + + Returns: bool + + """ + pass + + def can_verify(self) -> bool: + """This indicates if the authorizer can verify a CC token or not. + + Returns: bool + + """ + pass + + def get_namespace(self) -> str: + """This returns the namespace of the CCAuthorizer. + + Returns: namespace string + + """ + pass + + def generate(self) -> str: + """To generate and return the active CCAuthorizer token. + + Returns: token string + + """ + pass + + def verify(self, token: str) -> bool: + """To return the token verification result. + + Args: + token: bool + + Returns: + + """ + pass diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 60a2ef0b07..6e2e28b151 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError -from nvflare.app_opt.confidential_computing.tdx_connector import TDXCCHelper +from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit +from nvflare.app_opt.confidential_computing.tdx_connector import TDXConnector # from .cc_helper import CCHelper @@ -92,7 +94,7 @@ def __init__(self, verifiers: list): """ FLComponent.__init__(self) self.site_name = None - self.helper = None + self.cc_authorizer: TokenPundit = None self.verifiers = verifiers self.my_token = None self.participant_cc_info = {} # used by the Server to keep tokens of all clients @@ -146,7 +148,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def _prepare_token_for_login(self, fl_ctx: FLContext): # client side if self.my_token is None: - self.my_token = self.helper.get_token() + self.my_token = self.cc_authorizer.generate() cc_info = {CC_TOKEN: self.my_token} fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) @@ -172,13 +174,19 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: self.site_name = fl_ctx.get_identity_name() workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() # self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers) - self.helper = TDXCCHelper(site_name=self.site_name, - tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", - config_dir=workspace_folder) - ok = self.helper.prepare() - if not ok: - return "failed to attest" - self.my_token = self.helper.get_token() + # self.helper = TDXCCHelper(site_name=self.site_name, + # tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", + # config_dir=workspace_folder) + # ok = self.helper.prepare() + # if not ok: + # return "failed to attest" + + self.cc_authorizer = TDXConnector(tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", + config_dir=workspace_folder) + self.my_token = self.cc_authorizer.generate() + if not self.my_token: + return "failed to get CC token" + self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_TOKEN_VALIDATED: True} return "" @@ -233,7 +241,7 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: def _validate_participants_tokens(self, participants) -> str: self.logger.debug(f"Validating participant tokens {participants=}") - result = self.helper.validate_participants(participants) + result = self._validate_participants(participants) assert isinstance(result, dict) for p in result: self.participant_cc_info[p] = {CC_TOKEN: participants[p], CC_TOKEN_VALIDATED: True} @@ -245,6 +253,16 @@ def _validate_participants_tokens(self, participants) -> str: else: return "" + def _validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: + result = {} + if not participants: + return result + for k, v in participants.items(): + if self.cc_authorizer.verify(v): + result[k] = True + self.logger.info(f"CC - results from validating participants' tokens: {result}") + return result + def _not_authorize_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_connector.py index 24a22a5302..d9d4af7a13 100644 --- a/nvflare/app_opt/confidential_computing/tdx_connector.py +++ b/nvflare/app_opt/confidential_computing/tdx_connector.py @@ -11,20 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging + import os -import shlex import subprocess -from typing import Dict -# TDX_CLI_COMMAND = "./trustauthority-cli" +from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit + +TDX_NAMESPACE = "tdx_" TDX_CLI_CONFIG = "config.json" TOKEN_FILE = "token.txt" VERIFY_FILE = "verify.txt" ERROR_FILE = "error.txt" -class TDXConnector: +class TDXConnector(TokenPundit): def __init__(self, tdx_cli_command: str, config_dir: str) -> None: super().__init__() self.tdx_cli_command = tdx_cli_command @@ -32,34 +32,32 @@ def __init__(self, tdx_cli_command: str, config_dir: str) -> None: self.config_file = os.path.join(self.config_dir, TDX_CLI_CONFIG) - def get_token(self): + def generate(self) -> str: token_file = os.path.join(self.config_dir, TOKEN_FILE) out = open(token_file, "w") error_file = os.path.join(self.config_dir, ERROR_FILE) err_out = open(error_file, "w") - command = "sudo " + self.tdx_cli_command + " -c " + self.config_file + " token --no-eventlog " - process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) - process.wait() + command = ['sudo', self.tdx_cli_command, '-c', self.config_file, 'token', "--no-eventlog"] + subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) - with open(token_file, "r") as f: - token = f.readline() + # with open(token_file, "r") as f: + # token = f.readline() with open(error_file, "r") as f: if 'Error:' in f.read(): - error = True + return "" else: - error = False + with open(token_file, "r") as f: + token = f.readline() + return token - return token, error - - def verify_token(self, token: str): + def verify(self, token: str) -> bool: out = open(os.path.join(self.config_dir, VERIFY_FILE), "w") error_file = os.path.join(self.config_dir, ERROR_FILE) err_out = open(error_file, "w") - command = self.tdx_cli_command + " verify --config " + self.config_file + " --token " + token - process = subprocess.Popen(shlex.split(command, True), preexec_fn=os.setsid, stdout=out, stderr=err_out) - process.wait() + command = [self.tdx_cli_command, "verify", "--config", self.config_file, "--token", token] + subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) # with open(VERIFY_FILE, "r") as f: # result = f.readline() @@ -69,44 +67,59 @@ def verify_token(self, token: str): return True + def can_generate(self) -> bool: + return True + + def can_verify(self) -> bool: + return True + + def get_namespace(self) -> str: + return TDX_NAMESPACE -class TDXCCHelper: + # def generate(self) -> str: + # return super().generate() - def __init__(self, site_name: str, tdx_cli_command: str, config_dir: str) -> None: - super().__init__() - self.site_name = site_name - # self.tdx_cli_command = tdx_cli_command - # self.config_dir = config_dir - self.token = None - - self.tdx_connector = TDXConnector(tdx_cli_command, config_dir) - self.logger = logging.getLogger(self.__class__.__name__) - - def prepare(self) -> bool: - self.token, error = self.tdx_connector.get_token() - self.logger.info(f"site: {self.site_name} got the token: {self.token}") - return not error - - def get_token(self): - return self.token - - def validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: - result = {} - if not participants: - return result - for k, v in participants.items(): - if self.tdx_connector.verify_token(v): - result[k] = True - self.logger.info(f"CC - results from validating participants' tokens: {result}") - return result - - -if __name__ == "__main__": - tdx_connector = TDXConnector() - token, error = tdx_connector.get_token() - print("--- Acquire the token ---") - print(token) - - result = tdx_connector.verify_token(token) - print("---- Verify the token ---") - print(result) + # def verify(self, token: str) -> bool: + # return super().verify(token) + + +# class TDXCCHelper: +# +# def __init__(self, site_name: str, tdx_cli_command: str, config_dir: str) -> None: +# super().__init__() +# self.site_name = site_name +# # self.tdx_cli_command = tdx_cli_command +# # self.config_dir = config_dir +# self.token = None +# +# self.tdx_connector = TDXConnector(tdx_cli_command, config_dir) +# self.logger = logging.getLogger(self.__class__.__name__) +# +# def prepare(self) -> bool: +# self.token, error = self.tdx_connector.generate() +# self.logger.info(f"site: {self.site_name} got the token: {self.token}") +# return not error +# +# def get_token(self): +# return self.token +# +# def validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: +# result = {} +# if not participants: +# return result +# for k, v in participants.items(): +# if self.tdx_connector.verify(v): +# result[k] = True +# self.logger.info(f"CC - results from validating participants' tokens: {result}") +# return result + + +# if __name__ == "__main__": +# tdx_connector = TDXConnector() +# token = tdx_connector.generate() +# print("--- Acquire the token ---") +# print(token) +# +# result = tdx_connector.verify(token) +# print("---- Verify the token ---") +# print(result) From 044201a8f03b33a3cb9b755ab46582592cb4ce48 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 26 Feb 2024 15:26:28 -0500 Subject: [PATCH 12/44] WIP: added cc_authorizer_ids config. --- .../confidential_computing/cc_manager.py | 72 ++++---------- .../confidential_computing/gpu_authorizer.py | 94 +++++++++++++++++++ 2 files changed, 111 insertions(+), 55 deletions(-) create mode 100644 nvflare/app_opt/confidential_computing/gpu_authorizer.py diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 6e2e28b151..43fc378404 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -30,7 +30,7 @@ class CCManager(FLComponent): - def __init__(self, verifiers: list): + def __init__(self, cc_authorizer_ids: [str]): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -40,68 +40,25 @@ def __init__(self, verifiers: list): validating all tokens in the entire NVFlare system Args: - verifiers (list): - each element in this list is a dictionary and the keys of dictionary are - "devices", "env", "url", "appraisal_policy_file" and "result_policy_file." - - the values of devices are "gpu" and "cpu" - the values of env are "local" and "test" - currently, valid combination is gpu + local - - url must be an empty string - appraisal_policy_file must point to an existing file - currently supports an empty file only - - result_policy_file must point to an existing file - currently supports the following content only - - .. code-block:: json - - { - "version":"1.0", - "authorization-rules":{ - "x-nv-gpu-available":true, - "x-nv-gpu-attestation-report-available":true, - "x-nv-gpu-info-fetched":true, - "x-nv-gpu-arch-check":true, - "x-nv-gpu-root-cert-available":true, - "x-nv-gpu-cert-chain-verified":true, - "x-nv-gpu-ocsp-cert-chain-verified":true, - "x-nv-gpu-ocsp-signature-verified":true, - "x-nv-gpu-cert-ocsp-nonce-match":true, - "x-nv-gpu-cert-check-complete":true, - "x-nv-gpu-measurement-available":true, - "x-nv-gpu-attestation-report-parsed":true, - "x-nv-gpu-nonce-match":true, - "x-nv-gpu-attestation-report-driver-version-match":true, - "x-nv-gpu-attestation-report-vbios-version-match":true, - "x-nv-gpu-attestation-report-verified":true, - "x-nv-gpu-driver-rim-schema-fetched":true, - "x-nv-gpu-driver-rim-schema-validated":true, - "x-nv-gpu-driver-rim-cert-extracted":true, - "x-nv-gpu-driver-rim-signature-verified":true, - "x-nv-gpu-driver-rim-driver-measurements-available":true, - "x-nv-gpu-driver-vbios-rim-fetched":true, - "x-nv-gpu-vbios-rim-schema-validated":true, - "x-nv-gpu-vbios-rim-cert-extracted":true, - "x-nv-gpu-vbios-rim-signature-verified":true, - "x-nv-gpu-vbios-rim-driver-measurements-available":true, - "x-nv-gpu-vbios-index-conflict":true, - "x-nv-gpu-measurements-match":true - } - } """ FLComponent.__init__(self) self.site_name = None - self.cc_authorizer: TokenPundit = None - self.verifiers = verifiers + self.cc_authorizer_ids = cc_authorizer_ids + self.cc_authorizers = [] self.my_token = None self.participant_cc_info = {} # used by the Server to keep tokens of all clients def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.SYSTEM_BOOTSTRAP: try: + engine = fl_ctx.get_engine() + for id in self.cc_authorizer_ids: + authorizer = engine.get_component(id) + if not isinstance(authorizer, TokenPundit): + raise RuntimeError(f"cc_authorizer_id {id} must be a TokenPundit, but got {authorizer.__class__}") + self.cc_authorizers.append(authorizer) + err = self._prepare_for_attestation(fl_ctx) except: self.log_exception(fl_ctx, "exception in attestation preparation") @@ -181,12 +138,14 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # if not ok: # return "failed to attest" - self.cc_authorizer = TDXConnector(tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", - config_dir=workspace_folder) + # self.cc_authorizer = TDXConnector(tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", + # config_dir=workspace_folder) + self.cc_authorizer = self._get_authorizer() self.my_token = self.cc_authorizer.generate() if not self.my_token: return "failed to get CC token" + self.logger.info(f"site: {self.site_name} got the token: {self.my_token}") self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_TOKEN_VALIDATED: True} return "" @@ -274,3 +233,6 @@ def _block_job(self, reason: str, fl_ctx: FLContext): self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) + + def _get_authorizer(self): + return self.cc_authorizers[0] diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py new file mode 100644 index 0000000000..3290a3d4fb --- /dev/null +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -0,0 +1,94 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit + +GPU_NAMESPACE = "x-nv-gpu-" + + +class GPUPundit(TokenPundit): + def __init__(self, verifiers: list) -> None: + """ + + Args: + verifiers (list): + each element in this list is a dictionary and the keys of dictionary are + "devices", "env", "url", "appraisal_policy_file" and "result_policy_file." + + the values of devices are "gpu" and "cpu" + the values of env are "local" and "test" + currently, valid combination is gpu + local + + url must be an empty string + appraisal_policy_file must point to an existing file + currently supports an empty file only + + result_policy_file must point to an existing file + currently supports the following content only + + .. code-block:: json + + { + "version":"1.0", + "authorization-rules":{ + "x-nv-gpu-available":true, + "x-nv-gpu-attestation-report-available":true, + "x-nv-gpu-info-fetched":true, + "x-nv-gpu-arch-check":true, + "x-nv-gpu-root-cert-available":true, + "x-nv-gpu-cert-chain-verified":true, + "x-nv-gpu-ocsp-cert-chain-verified":true, + "x-nv-gpu-ocsp-signature-verified":true, + "x-nv-gpu-cert-ocsp-nonce-match":true, + "x-nv-gpu-cert-check-complete":true, + "x-nv-gpu-measurement-available":true, + "x-nv-gpu-attestation-report-parsed":true, + "x-nv-gpu-nonce-match":true, + "x-nv-gpu-attestation-report-driver-version-match":true, + "x-nv-gpu-attestation-report-vbios-version-match":true, + "x-nv-gpu-attestation-report-verified":true, + "x-nv-gpu-driver-rim-schema-fetched":true, + "x-nv-gpu-driver-rim-schema-validated":true, + "x-nv-gpu-driver-rim-cert-extracted":true, + "x-nv-gpu-driver-rim-signature-verified":true, + "x-nv-gpu-driver-rim-driver-measurements-available":true, + "x-nv-gpu-driver-vbios-rim-fetched":true, + "x-nv-gpu-vbios-rim-schema-validated":true, + "x-nv-gpu-vbios-rim-cert-extracted":true, + "x-nv-gpu-vbios-rim-signature-verified":true, + "x-nv-gpu-vbios-rim-driver-measurements-available":true, + "x-nv-gpu-vbios-index-conflict":true, + "x-nv-gpu-measurements-match":true + } + } + + """ + super().__init__() + self.verifiers = verifiers + + def can_generate(self) -> bool: + return True + + def can_verify(self) -> bool: + return True + + def get_namespace(self) -> str: + return GPU_NAMESPACE + + def generate(self) -> str: + return super().generate() + + def verify(self, token: str) -> bool: + return super().verify(token) From 3c593c06fbf4d14b0c68b26ee6158c581df34d90 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 27 Feb 2024 14:42:16 -0500 Subject: [PATCH 13/44] Added cc_issuer_id for CCManager. --- .../confidential_computing/cc_manager.py | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 43fc378404..ce94846959 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -19,18 +19,18 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit -from nvflare.app_opt.confidential_computing.tdx_connector import TDXConnector # from .cc_helper import CCHelper PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" CC_TOKEN = "_cc_token" +CC_NAMESPACE = "_cc_namespace" CC_INFO = "_cc_info" CC_TOKEN_VALIDATED = "_cc_token_validated" class CCManager(FLComponent): - def __init__(self, cc_authorizer_ids: [str]): + def __init__(self, cc_issuer_id: str, cc_verifier_ids: [str]): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -44,20 +44,17 @@ def __init__(self, cc_authorizer_ids: [str]): """ FLComponent.__init__(self) self.site_name = None - self.cc_authorizer_ids = cc_authorizer_ids - self.cc_authorizers = [] + self.cc_issuer_id = cc_issuer_id + self.cc_verifier_ids = cc_verifier_ids + self.cc_issuer = None + self.cc_verifiers = {} self.my_token = None self.participant_cc_info = {} # used by the Server to keep tokens of all clients def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.SYSTEM_BOOTSTRAP: try: - engine = fl_ctx.get_engine() - for id in self.cc_authorizer_ids: - authorizer = engine.get_component(id) - if not isinstance(authorizer, TokenPundit): - raise RuntimeError(f"cc_authorizer_id {id} must be a TokenPundit, but got {authorizer.__class__}") - self.cc_authorizers.append(authorizer) + self._setup_cc_authorizers(fl_ctx) err = self._prepare_for_attestation(fl_ctx) except: @@ -102,18 +99,30 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): # Server side fl_ctx.remove_prop(PEER_CTX_CC_TOKEN) + def _setup_cc_authorizers(self, fl_ctx): + engine = fl_ctx.get_engine() + self.cc_issuer = engine.get_component(self.cc_issuer_id) + if not isinstance(self.cc_issuer, TokenPundit): + raise RuntimeError(f"cc_authorizer_id {self.cc_issuer_id} must be a TokenPundit, but got {self.cc_issuer.__class__}") + + for id in self.cc_verifier_ids: + authorizer = engine.get_component(id) + if not isinstance(authorizer, TokenPundit): + raise RuntimeError(f"cc_authorizer_id {id} must be a TokenPundit, but got {authorizer.__class__}") + self.cc_verifiers[authorizer.get_namespace()] = authorizer + def _prepare_token_for_login(self, fl_ctx: FLContext): # client side if self.my_token is None: - self.my_token = self.cc_authorizer.generate() - cc_info = {CC_TOKEN: self.my_token} + self.my_token = self.cc_issuer.generate() + cc_info = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()} fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: ""}) + peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: "", CC_NAMESPACE: ""}) self.participant_cc_info[token_owner] = peer_cc_info self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False @@ -140,13 +149,14 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # self.cc_authorizer = TDXConnector(tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", # config_dir=workspace_folder) - self.cc_authorizer = self._get_authorizer() - self.my_token = self.cc_authorizer.generate() + + # self.cc_authorizer = self._get_authorizer() + self.my_token = self.cc_issuer.generate() if not self.my_token: return "failed to get CC token" self.logger.info(f"site: {self.site_name} got the token: {self.my_token}") - self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_TOKEN_VALIDATED: True} + self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace(), CC_TOKEN_VALIDATED: True} return "" def _client_to_check_participant_token(self, fl_ctx: FLContext) -> str: @@ -176,7 +186,8 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" - participant_tokens = {self.site_name: self.my_token} + # participant_tokens = {self.site_name: self.my_token, CC_VERIFIER: self.cc_issuer} + participant_tokens = {self.site_name: self.participant_cc_info[self.site_name]} for p in participants: assert isinstance(p, str) if p == self.site_name: @@ -184,9 +195,10 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if p not in self.participant_cc_info: return f"no token available for participant {p}" if self.participant_cc_info.get(p): - participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] + # participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] + participant_tokens[p] = self.participant_cc_info[p] else: - participant_tokens[p] = "" + participant_tokens[p] = {} err = self._validate_participants_tokens(participant_tokens) if err: @@ -212,13 +224,16 @@ def _validate_participants_tokens(self, participants) -> str: else: return "" - def _validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: + def _validate_participants(self, participants: Dict[str, {}]) -> Dict[str, bool]: result = {} if not participants: return result for k, v in participants.items(): - if self.cc_authorizer.verify(v): - result[k] = True + token = v.get(CC_TOKEN, "") + verifier = self.cc_verifiers.get(v.get(CC_NAMESPACE, ""), None) + if verifier: + if verifier.verify(token): + result[k] = True self.logger.info(f"CC - results from validating participants' tokens: {result}") return result @@ -233,6 +248,3 @@ def _block_job(self, reason: str, fl_ctx: FLContext): self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) - - def _get_authorizer(self): - return self.cc_authorizers[0] From 7f74d68a2ae1bf1d69fdc5a7eea9cd6378f600df Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 28 Feb 2024 10:19:17 -0500 Subject: [PATCH 14/44] renamed the TokenPundit to CCAutorizer. --- .../confidential_computing/cc_authorizer.py | 2 +- .../confidential_computing/cc_manager.py | 37 ++++--------- .../confidential_computing/gpu_authorizer.py | 4 +- .../confidential_computing/tdx_connector.py | 52 +------------------ 4 files changed, 15 insertions(+), 80 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py index 362000b6a9..fb1fc525bb 100644 --- a/nvflare/app_opt/confidential_computing/cc_authorizer.py +++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py @@ -13,7 +13,7 @@ # limitations under the License. # import os.path -class TokenPundit: +class CCAuthorizer: def can_generate(self) -> bool: """This indicates if the authorizer can generate a CC token or not. diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index ce94846959..0de3d78fce 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -18,9 +18,7 @@ from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError -from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit - -# from .cc_helper import CCHelper +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" CC_TOKEN = "_cc_token" @@ -102,13 +100,13 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def _setup_cc_authorizers(self, fl_ctx): engine = fl_ctx.get_engine() self.cc_issuer = engine.get_component(self.cc_issuer_id) - if not isinstance(self.cc_issuer, TokenPundit): - raise RuntimeError(f"cc_authorizer_id {self.cc_issuer_id} must be a TokenPundit, but got {self.cc_issuer.__class__}") + if not isinstance(self.cc_issuer, CCAuthorizer): + raise RuntimeError(f"cc_authorizer_id {self.cc_issuer_id} must be a CCAuthorizer, but got {self.cc_issuer.__class__}") - for id in self.cc_verifier_ids: - authorizer = engine.get_component(id) - if not isinstance(authorizer, TokenPundit): - raise RuntimeError(f"cc_authorizer_id {id} must be a TokenPundit, but got {authorizer.__class__}") + for v_id in self.cc_verifier_ids: + authorizer = engine.get_component(v_id) + if not isinstance(authorizer, CCAuthorizer): + raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {authorizer.__class__}") self.cc_verifiers[authorizer.get_namespace()] = authorizer def _prepare_token_for_login(self, fl_ctx: FLContext): @@ -139,18 +137,6 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() - # self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers) - # self.helper = TDXCCHelper(site_name=self.site_name, - # tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", - # config_dir=workspace_folder) - # ok = self.helper.prepare() - # if not ok: - # return "failed to attest" - - # self.cc_authorizer = TDXConnector(tdx_cli_command="/home/azureuser/TDX/client/tdx-cli/trustauthority-cli", - # config_dir=workspace_folder) - - # self.cc_authorizer = self._get_authorizer() self.my_token = self.cc_issuer.generate() if not self.my_token: return "failed to get CC token" @@ -186,7 +172,6 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" - # participant_tokens = {self.site_name: self.my_token, CC_VERIFIER: self.cc_issuer} participant_tokens = {self.site_name: self.participant_cc_info[self.site_name]} for p in participants: assert isinstance(p, str) @@ -195,7 +180,6 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if p not in self.participant_cc_info: return f"no token available for participant {p}" if self.participant_cc_info.get(p): - # participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN] participant_tokens[p] = self.participant_cc_info[p] else: participant_tokens[p] = {} @@ -224,16 +208,15 @@ def _validate_participants_tokens(self, participants) -> str: else: return "" - def _validate_participants(self, participants: Dict[str, {}]) -> Dict[str, bool]: + def _validate_participants(self, participants: Dict[str, Dict[str, str]]) -> Dict[str, bool]: result = {} if not participants: return result for k, v in participants.items(): token = v.get(CC_TOKEN, "") verifier = self.cc_verifiers.get(v.get(CC_NAMESPACE, ""), None) - if verifier: - if verifier.verify(token): - result[k] = True + if verifier and verifier.verify(token): + result[k] = True self.logger.info(f"CC - results from validating participants' tokens: {result}") return result diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py index 3290a3d4fb..73cf8c7682 100644 --- a/nvflare/app_opt/confidential_computing/gpu_authorizer.py +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -13,12 +13,12 @@ # limitations under the License. -from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer GPU_NAMESPACE = "x-nv-gpu-" -class GPUPundit(TokenPundit): +class GPUAuthorizer(CCAuthorizer): def __init__(self, verifiers: list) -> None: """ diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_connector.py index d9d4af7a13..2cb657396c 100644 --- a/nvflare/app_opt/confidential_computing/tdx_connector.py +++ b/nvflare/app_opt/confidential_computing/tdx_connector.py @@ -15,7 +15,7 @@ import os import subprocess -from nvflare.app_opt.confidential_computing.cc_authorizer import TokenPundit +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer TDX_NAMESPACE = "tdx_" TDX_CLI_CONFIG = "config.json" @@ -24,7 +24,7 @@ ERROR_FILE = "error.txt" -class TDXConnector(TokenPundit): +class TDXConnector(CCAuthorizer): def __init__(self, tdx_cli_command: str, config_dir: str) -> None: super().__init__() self.tdx_cli_command = tdx_cli_command @@ -75,51 +75,3 @@ def can_verify(self) -> bool: def get_namespace(self) -> str: return TDX_NAMESPACE - - # def generate(self) -> str: - # return super().generate() - - # def verify(self, token: str) -> bool: - # return super().verify(token) - - -# class TDXCCHelper: -# -# def __init__(self, site_name: str, tdx_cli_command: str, config_dir: str) -> None: -# super().__init__() -# self.site_name = site_name -# # self.tdx_cli_command = tdx_cli_command -# # self.config_dir = config_dir -# self.token = None -# -# self.tdx_connector = TDXConnector(tdx_cli_command, config_dir) -# self.logger = logging.getLogger(self.__class__.__name__) -# -# def prepare(self) -> bool: -# self.token, error = self.tdx_connector.generate() -# self.logger.info(f"site: {self.site_name} got the token: {self.token}") -# return not error -# -# def get_token(self): -# return self.token -# -# def validate_participants(self, participants: Dict[str, str]) -> Dict[str, bool]: -# result = {} -# if not participants: -# return result -# for k, v in participants.items(): -# if self.tdx_connector.verify(v): -# result[k] = True -# self.logger.info(f"CC - results from validating participants' tokens: {result}") -# return result - - -# if __name__ == "__main__": -# tdx_connector = TDXConnector() -# token = tdx_connector.generate() -# print("--- Acquire the token ---") -# print(token) -# -# result = tdx_connector.verify(token) -# print("---- Verify the token ---") -# print(result) From ff5555465c17999e5301b4bec007402ce9f7cb91 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 29 Feb 2024 10:08:05 -0500 Subject: [PATCH 15/44] Added CC token adding through client heartbeat. --- nvflare/apis/event_type.py | 2 ++ .../confidential_computing/cc_manager.py | 23 +++++++++++-------- nvflare/private/fed/client/communicator.py | 9 +++++++- nvflare/private/fed/server/fed_server.py | 5 ++++ 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index bddfdb9818..4c9ff1856a 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -74,6 +74,8 @@ class EventType(object): CLIENT_REGISTERED = "_client_registered" CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" + BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" + AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 0de3d78fce..bf4a77ecf2 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -62,10 +62,10 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if err: self.log_critical(fl_ctx, err, fire_event=False) raise UnsafeComponentError(err) - elif event_type == EventType.BEFORE_CLIENT_REGISTER: + elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side self._prepare_token_for_login(fl_ctx) - elif event_type == EventType.CLIENT_REGISTERED: + elif event_type == EventType.CLIENT_REGISTERED or event_type == EventType.AFTER_CLIENT_HEARTBEAT: # Server side self._add_client_token(fl_ctx) elif event_type == EventType.CLIENT_QUIT: @@ -110,9 +110,10 @@ def _setup_cc_authorizers(self, fl_ctx): self.cc_verifiers[authorizer.get_namespace()] = authorizer def _prepare_token_for_login(self, fl_ctx: FLContext): - # client side - if self.my_token is None: + # client side, if token expired then generate a new one + if not self.cc_issuer.verify(self.my_token): self.my_token = self.cc_issuer.generate() + self.logger.info(f"site: {self.site_name} got a new CC token: {self.my_token}") cc_info = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()} fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) @@ -121,10 +122,11 @@ def _add_client_token(self, fl_ctx: FLContext): peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: "", CC_NAMESPACE: ""}) - self.participant_cc_info[token_owner] = peer_cc_info - self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False - - self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") + old_cc_info = self.participant_cc_info.get(token_owner) + if not old_cc_info or old_cc_info.get(CC_TOKEN) != peer_cc_info[CC_TOKEN]: + self.participant_cc_info[token_owner] = peer_cc_info + self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False + self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") def _remove_client_token(self, fl_ctx: FLContext): # server side @@ -172,7 +174,10 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" - participant_tokens = {self.site_name: self.participant_cc_info[self.site_name]} + # if server token expired, then generates a new one + if not self.cc_issuer.verify(self.my_token): + self.my_token = self.cc_issuer.generate() + participant_tokens = {self.site_name: {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()}} for p in participants: assert isinstance(p, str) if p == self.site_name: diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 7dc6525516..f4d3253a16 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -17,6 +17,7 @@ import time from typing import List, Optional +from nvflare.apis.event_type import EventType from nvflare.apis.filter import Filter from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_constant import ReturnCode as ShareableRC @@ -332,6 +333,11 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C heartbeats_log_interval = 10 while not self.heartbeat_done: try: + engine.fire_event(EventType.BEFORE_CLIENT_HEARTBEAT, fl_ctx) + shareable = Shareable() + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + job_ids = engine.get_all_job_ids() heartbeat_message = new_cell_message( { @@ -340,7 +346,8 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, - } + }, + shareable ) try: diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 1bc9b22bca..e87c04e6a4 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -577,6 +577,11 @@ def client_heartbeat(self, request: Message) -> Message: if error is not None: return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) + self.engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) + token = request.get_header(CellMessageHeaderKeys.TOKEN) client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME) From 2c5f3bd84a18a04ec137d00d247f9dad2db2a7ae Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 29 Feb 2024 11:59:38 -0500 Subject: [PATCH 16/44] Added function to stop current running job if CC verify fail. --- .../confidential_computing/cc_manager.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index bf4a77ecf2..215a79323b 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -15,7 +15,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey +from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer @@ -128,6 +128,19 @@ def _add_client_token(self, fl_ctx: FLContext): self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") + self._verify_running_jobs(fl_ctx) + + def _verify_running_jobs(self, fl_ctx): + engine = fl_ctx.get_engine() + run_processes = engine.run_processes + running_jobs = list(run_processes.keys()) + for job_id in running_jobs: + participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) + participant_tokens = {} + err = self._verify_participants(participants, participant_tokens) + if err: + engine.job_runner.stop_run(job_id, fl_ctx) + def _remove_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() @@ -174,10 +187,23 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" + participant_tokens = {} + err = self._verify_participants(participants, participant_tokens) + if err: + return err + + for p in participant_tokens: + self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True + fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False) + self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") + return "" + + def _verify_participants(self, participants, participant_tokens): # if server token expired, then generates a new one if not self.cc_issuer.verify(self.my_token): self.my_token = self.cc_issuer.generate() - participant_tokens = {self.site_name: {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()}} + # participant_tokens = {self.site_name: {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()}} + participant_tokens[self.site_name] = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()} for p in participants: assert isinstance(p, str) if p == self.site_name: @@ -188,16 +214,7 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: participant_tokens[p] = self.participant_cc_info[p] else: participant_tokens[p] = {} - - err = self._validate_participants_tokens(participant_tokens) - if err: - return err - - for p in participant_tokens: - self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True - fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False) - self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") - return "" + return self._validate_participants_tokens(participant_tokens) def _validate_participants_tokens(self, participants) -> str: self.logger.debug(f"Validating participant tokens {participants=}") From 2f383d65aa81c7048ad47e17681d30a80ed7349d Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 29 Feb 2024 14:26:50 -0500 Subject: [PATCH 17/44] if CC failed to get toke, don't allow the system to start. --- .../app_opt/confidential_computing/cc_manager.py | 13 +++++++++++-- nvflare/private/fed/app/client/client_train.py | 6 ++++++ nvflare/private/fed/app/deployer/server_deployer.py | 6 ++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 215a79323b..5f0d6eaa8c 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import threading from typing import Dict from nvflare.apis.event_type import EventType @@ -49,6 +50,8 @@ def __init__(self, cc_issuer_id: str, cc_verifier_ids: [str]): self.my_token = None self.participant_cc_info = {} # used by the Server to keep tokens of all clients + self.lock = threading.Lock() + def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.SYSTEM_BOOTSTRAP: try: @@ -128,18 +131,24 @@ def _add_client_token(self, fl_ctx: FLContext): self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") - self._verify_running_jobs(fl_ctx) + with self.lock: + self._verify_running_jobs(fl_ctx) def _verify_running_jobs(self, fl_ctx): engine = fl_ctx.get_engine() run_processes = engine.run_processes running_jobs = list(run_processes.keys()) for job_id in running_jobs: - participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) + job_participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) + participants = [] + for _, client in job_participants.items(): + participants.append(client.name) + participant_tokens = {} err = self._verify_participants(participants, participant_tokens) if err: engine.job_runner.stop_run(job_id, fl_ctx) + self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") def _remove_client_token(self, fl_ctx: FLContext): # server side diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index aa68ad5adb..54b8181a8a 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -21,6 +21,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, JobConstants, SiteType, WorkspaceConstants +from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.workspace import Workspace from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm @@ -111,6 +112,11 @@ def main(args): fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + raise RuntimeError(exception) + client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) federated_client.register(fl_ctx) fl_ctx.set_prop(FLContextKey.CLIENT_TOKEN, federated_client.token) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index 899832dcba..82b05863a4 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -17,6 +17,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import SystemComponents, FLContextKey +from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.workspace import Workspace from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.job_runner import JobRunner @@ -122,6 +123,11 @@ def deploy(self, args): fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + raise RuntimeError(exception) + threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() services.status = ServerStatus.STARTED From ea5ae6103a9f332014972dcf9b30156435b2f58e Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 29 Feb 2024 14:30:59 -0500 Subject: [PATCH 18/44] Added exceptions None check. --- nvflare/private/fed/app/client/client_train.py | 7 ++++--- nvflare/private/fed/app/deployer/server_deployer.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 54b8181a8a..adb5a4fac5 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -113,9 +113,10 @@ def main(args): client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) - for _, exception in exceptions.items(): - if isinstance(exception, UnsafeComponentError): - raise RuntimeError(exception) + if exceptions: + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + raise RuntimeError(exception) client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) federated_client.register(fl_ctx) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index 82b05863a4..a2da643518 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -124,9 +124,10 @@ def deploy(self, args): services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) - for _, exception in exceptions.items(): - if isinstance(exception, UnsafeComponentError): - raise RuntimeError(exception) + if exceptions: + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + raise RuntimeError(exception) threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() services.status = ServerStatus.STARTED From 13e0b6b90671330eff6a93778520b2cbe1473164 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 1 Mar 2024 10:17:28 -0500 Subject: [PATCH 19/44] Address the client side CC check before job scheduled. --- .../confidential_computing/cc_manager.py | 27 ++++++++++--------- nvflare/private/fed/client/scheduler_cmds.py | 4 ++- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 5f0d6eaa8c..87cde11929 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -74,20 +74,21 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): elif event_type == EventType.CLIENT_QUIT: # Server side self._remove_client_token(fl_ctx) - elif event_type == EventType.AUTHORIZE_COMMAND_CHECK: - command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) - self.logger.debug(f"Received {command_to_check=}") - if command_to_check == AdminCommandNames.CHECK_RESOURCES: - try: - err = self._client_to_check_participant_token(fl_ctx) - except: - self.log_exception(fl_ctx, "exception in validating participants") - err = "Participants unable to meet client CC requirements" - finally: - if err: - self._not_authorize_job(err, fl_ctx) + elif event_type == EventType.BEFORE_CHECK_RESOURCE_MANAGER: + # Client Side, job scheduler check resource + # command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) + # self.logger.debug(f"Received {command_to_check=}") + # if command_to_check == AdminCommandNames.CHECK_RESOURCES: + try: + err = self._client_to_check_participant_token(fl_ctx) + except: + self.log_exception(fl_ctx, "exception in validating participants") + err = "Participants unable to meet client CC requirements" + finally: + if err: + self._not_authorize_job(err, fl_ctx) elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: - # Server side + # Server side, job scheduler check client resources try: err = self._server_to_check_client_token(fl_ctx) except: diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 58828121bb..e6e7a66ba4 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -18,7 +18,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec -from nvflare.apis.shareable import Shareable +from nvflare.apis.shareable import Shareable, ReservedHeaderKey from nvflare.private.admin_defs import Message from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader, SysCommandTopic, TrainingTopic from nvflare.private.fed.client.admin import RequestProcessor @@ -68,6 +68,8 @@ def process(self, req: Message, app_ctx) -> Message: fl_ctx.set_prop(key=FLContextKey.CLIENT_RESOURCE_SPECS, value=resource_spec, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) + shared_fl_ctx = req.get_header(ReservedHeaderKey.PEER_PROPS) + fl_ctx.set_peer_context(shared_fl_ctx) engine.fire_event(EventType.BEFORE_CHECK_RESOURCE_MANAGER, fl_ctx) block_reason = fl_ctx.get_prop(FLContextKey.JOB_BLOCK_REASON) From 68d8d9101c63999b01edc8d1a57b712a15e8643a Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 1 Mar 2024 10:58:58 -0500 Subject: [PATCH 20/44] fixed the PEER_FL_CONTEXT error. --- nvflare/private/fed/client/scheduler_cmds.py | 6 +++--- nvflare/private/fed/server/admin.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index e6e7a66ba4..8ec3e0d855 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -16,7 +16,7 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents +from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents, ServerCommandKey from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec from nvflare.apis.shareable import Shareable, ReservedHeaderKey from nvflare.private.admin_defs import Message @@ -68,7 +68,7 @@ def process(self, req: Message, app_ctx) -> Message: fl_ctx.set_prop(key=FLContextKey.CLIENT_RESOURCE_SPECS, value=resource_spec, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) - shared_fl_ctx = req.get_header(ReservedHeaderKey.PEER_PROPS) + shared_fl_ctx = req.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) engine.fire_event(EventType.BEFORE_CHECK_RESOURCE_MANAGER, fl_ctx) @@ -80,7 +80,7 @@ def process(self, req: Message, app_ctx) -> Message: is_resource_enough, token = resource_manager.check_resources( resource_requirement=resource_spec, fl_ctx=fl_ctx ) - except Exception: + except Exception as e: result.set_return_code(ReturnCode.EXECUTION_EXCEPTION) result.set_header(ShareableHeader.IS_RESOURCE_ENOUGH, is_resource_enough) diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index eec732013a..7581502d81 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -17,7 +17,9 @@ from typing import List, Optional from nvflare.apis.event_type import EventType +from nvflare.apis.fl_constant import ServerCommandKey from nvflare.apis.shareable import ReservedHeaderKey +from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.net_manager import NetManager @@ -276,7 +278,8 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl for _, request in requests.items(): with self.sai.new_context() as fl_ctx: self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) - request.set_header(ReservedHeaderKey.PEER_PROPS, copy.deepcopy(fl_ctx.get_all_public_props())) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) return send_requests( cell=self.cell, From b9942e3740978d898a0500e08d9638e131250237 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Sat, 2 Mar 2024 13:31:12 -0500 Subject: [PATCH 21/44] Added CCManager support to have multiple cc_issuers. --- .../confidential_computing/cc_manager.py | 130 +++++++++++------- 1 file changed, 83 insertions(+), 47 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 87cde11929..6a0c8a3e15 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -12,24 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading -from typing import Dict +from typing import Dict, List from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey +from nvflare.apis.fl_constant import FLContextKey, RunProcessKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" CC_TOKEN = "_cc_token" +CC_ISSUER = "_cc_issuer" CC_NAMESPACE = "_cc_namespace" CC_INFO = "_cc_info" CC_TOKEN_VALIDATED = "_cc_token_validated" class CCManager(FLComponent): - def __init__(self, cc_issuer_id: str, cc_verifier_ids: [str]): + def __init__(self, cc_issuer_ids: str, cc_verifier_ids: [str]): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -43,11 +44,10 @@ def __init__(self, cc_issuer_id: str, cc_verifier_ids: [str]): """ FLComponent.__init__(self) self.site_name = None - self.cc_issuer_id = cc_issuer_id + self.cc_issuer_ids = cc_issuer_ids self.cc_verifier_ids = cc_verifier_ids - self.cc_issuer = None + self.cc_issuers = [] self.cc_verifiers = {} - self.my_token = None self.participant_cc_info = {} # used by the Server to keep tokens of all clients self.lock = threading.Lock() @@ -75,10 +75,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): # Server side self._remove_client_token(fl_ctx) elif event_type == EventType.BEFORE_CHECK_RESOURCE_MANAGER: - # Client Side, job scheduler check resource - # command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) - # self.logger.debug(f"Received {command_to_check=}") - # if command_to_check == AdminCommandNames.CHECK_RESOURCES: + # Client side, check resources before job scheduled try: err = self._client_to_check_participant_token(fl_ctx) except: @@ -103,34 +100,47 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def _setup_cc_authorizers(self, fl_ctx): engine = fl_ctx.get_engine() - self.cc_issuer = engine.get_component(self.cc_issuer_id) - if not isinstance(self.cc_issuer, CCAuthorizer): - raise RuntimeError(f"cc_authorizer_id {self.cc_issuer_id} must be a CCAuthorizer, but got {self.cc_issuer.__class__}") + for i_id in self.cc_issuer_ids: + issuer = engine.get_component(i_id) + if not isinstance(issuer, CCAuthorizer): + raise RuntimeError(f"cc_issuer_id {i_id} must be a CCAuthorizer, but got {issuer.__class__}") + self.cc_issuers.append(issuer) for v_id in self.cc_verifier_ids: authorizer = engine.get_component(v_id) if not isinstance(authorizer, CCAuthorizer): raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {authorizer.__class__}") - self.cc_verifiers[authorizer.get_namespace()] = authorizer + namespace = authorizer.get_namespace() + if namespace in self.cc_verifiers.keys(): + raise RuntimeError(f"Authorizer with namespace: {namespace} already exist.") + self.cc_verifiers[namespace] = authorizer def _prepare_token_for_login(self, fl_ctx: FLContext): # client side, if token expired then generate a new one - if not self.cc_issuer.verify(self.my_token): - self.my_token = self.cc_issuer.generate() - self.logger.info(f"site: {self.site_name} got a new CC token: {self.my_token}") - cc_info = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()} + self._handle_expired_tokens() + + site_cc_info = self.participant_cc_info[self.site_name] + cc_info = self._get_participant_tokens(site_cc_info) fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - peer_cc_info = peer_ctx.get_prop(CC_INFO, {CC_TOKEN: "", CC_NAMESPACE: ""}) + peer_cc_info = peer_ctx.get_prop(CC_INFO, [{CC_TOKEN: "", CC_NAMESPACE: ""}]) + new_tokens = [] + for i in peer_cc_info: + new_tokens.append(i[CC_TOKEN]) + old_cc_info = self.participant_cc_info.get(token_owner) - if not old_cc_info or old_cc_info.get(CC_TOKEN) != peer_cc_info[CC_TOKEN]: + old_tokens = [] + if old_cc_info: + for i in old_cc_info: + old_tokens.append(i[CC_TOKEN]) + + if not old_cc_info or set(new_tokens) != set(old_tokens): self.participant_cc_info[token_owner] = peer_cc_info - self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False - self.logger.info(f"Added CC client: {token_owner} token: {peer_cc_info[CC_TOKEN]}") + self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") with self.lock: self._verify_running_jobs(fl_ctx) @@ -162,12 +172,19 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() - self.my_token = self.cc_issuer.generate() - if not self.my_token: - return "failed to get CC token" - self.logger.info(f"site: {self.site_name} got the token: {self.my_token}") - self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace(), CC_TOKEN_VALIDATED: True} + self.participant_cc_info[self.site_name] = [] + for issuer in self.cc_issuers: + my_token = issuer.generate() + namespace = issuer.get_namespace() + + if not my_token: + return "failed to get CC token" + + self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") + cc_info = {CC_TOKEN: my_token, CC_ISSUER: issuer, CC_NAMESPACE: namespace, CC_TOKEN_VALIDATED: True} + self.participant_cc_info[self.site_name].append(cc_info) + return "" def _client_to_check_participant_token(self, fl_ctx: FLContext) -> str: @@ -202,18 +219,17 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if err: return err - for p in participant_tokens: - self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False) self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") return "" def _verify_participants(self, participants, participant_tokens): # if server token expired, then generates a new one - if not self.cc_issuer.verify(self.my_token): - self.my_token = self.cc_issuer.generate() - # participant_tokens = {self.site_name: {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()}} - participant_tokens[self.site_name] = {CC_TOKEN: self.my_token, CC_NAMESPACE: self.cc_issuer.get_namespace()} + self._handle_expired_tokens() + + site_cc_info = self.participant_cc_info[self.site_name] + participant_tokens[self.site_name] = self._get_participant_tokens(site_cc_info) + for p in participants: assert isinstance(p, str) if p == self.site_name: @@ -221,18 +237,32 @@ def _verify_participants(self, participants, participant_tokens): if p not in self.participant_cc_info: return f"no token available for participant {p}" if self.participant_cc_info.get(p): - participant_tokens[p] = self.participant_cc_info[p] + participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p]) else: - participant_tokens[p] = {} + participant_tokens[p] = [{}] return self._validate_participants_tokens(participant_tokens) + def _get_participant_tokens(self, site_cc_info): + cc_info = [] + for i in site_cc_info: + namespace = i.get(CC_NAMESPACE) + token = i.get(CC_TOKEN) + cc_info.append({CC_TOKEN: token, CC_NAMESPACE: namespace, CC_TOKEN_VALIDATED: False}) + return cc_info + + def _handle_expired_tokens(self): + site_cc_info = self.participant_cc_info[self.site_name] + for i in site_cc_info: + issuer = i.get(CC_ISSUER) + token = i.get(CC_TOKEN) + if not issuer.verify(token): + token = issuer.generate() + i[CC_TOKEN] = token + self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") + def _validate_participants_tokens(self, participants) -> str: self.logger.debug(f"Validating participant tokens {participants=}") - result = self._validate_participants(participants) - assert isinstance(result, dict) - for p in result: - self.participant_cc_info[p] = {CC_TOKEN: participants[p], CC_TOKEN_VALIDATED: True} - invalid_participant_list = [k for k, v in self.participant_cc_info.items() if v[CC_TOKEN_VALIDATED] is False] + result, invalid_participant_list = self._validate_participants(participants) if invalid_participant_list: invalid_participant_string = ",".join(invalid_participant_list) self.logger.debug(f"{invalid_participant_list=}") @@ -240,17 +270,23 @@ def _validate_participants_tokens(self, participants) -> str: else: return "" - def _validate_participants(self, participants: Dict[str, Dict[str, str]]) -> Dict[str, bool]: + def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) -> (Dict[str, bool], List[str]): result = {} + invalid_participant_list = [] if not participants: return result - for k, v in participants.items(): - token = v.get(CC_TOKEN, "") - verifier = self.cc_verifiers.get(v.get(CC_NAMESPACE, ""), None) - if verifier and verifier.verify(token): - result[k] = True + + for k, cc_info in participants.items(): + for v in cc_info: + token = v.get(CC_TOKEN, "") + namespace = v.get(CC_NAMESPACE, "") + verifier = self.cc_verifiers.get(namespace, None) + if verifier and verifier.verify(token): + result[k + "." + namespace] = True + else: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") self.logger.info(f"CC - results from validating participants' tokens: {result}") - return result + return result, invalid_participant_list def _not_authorize_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") From 91e3f40ea02fd556376c610192f210564c765ea3 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 4 Mar 2024 10:47:01 -0500 Subject: [PATCH 22/44] optimized CCManager. --- nvflare/app_opt/confidential_computing/cc_manager.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 6a0c8a3e15..6dc0ae216c 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -30,11 +30,11 @@ class CCManager(FLComponent): - def __init__(self, cc_issuer_ids: str, cc_verifier_ids: [str]): + def __init__(self, cc_issuer_ids: [str], cc_verifier_ids: [str]): """Manage all confidential computing related tasks. This manager does the following tasks: - obtaining its own GPU CC token + obtaining its own CC token preparing the token to the server keeping clients' tokens in server validating all tokens in the entire NVFlare system @@ -75,7 +75,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): # Server side self._remove_client_token(fl_ctx) elif event_type == EventType.BEFORE_CHECK_RESOURCE_MANAGER: - # Client side, check resources before job scheduled + # Client side: check resources before job scheduled try: err = self._client_to_check_participant_token(fl_ctx) except: @@ -85,7 +85,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if err: self._not_authorize_job(err, fl_ctx) elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: - # Server side, job scheduler check client resources + # Server side: job scheduler check client resources try: err = self._server_to_check_client_token(fl_ctx) except: @@ -116,7 +116,7 @@ def _setup_cc_authorizers(self, fl_ctx): self.cc_verifiers[namespace] = authorizer def _prepare_token_for_login(self, fl_ctx: FLContext): - # client side, if token expired then generate a new one + # client side: if token expired then generate a new one self._handle_expired_tokens() site_cc_info = self.participant_cc_info[self.site_name] @@ -257,6 +257,8 @@ def _handle_expired_tokens(self): token = i.get(CC_TOKEN) if not issuer.verify(token): token = issuer.generate() + if not token: + raise RuntimeError(f"{self.site_name} failed to generate a new CC token") i[CC_TOKEN] = token self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") From 6206452b066821958153aecc9c180f9be3489c8d Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 4 Mar 2024 12:08:14 -0500 Subject: [PATCH 23/44] updated the _verify_participants() logic. --- .../app_opt/confidential_computing/cc_manager.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 6dc0ae216c..d84d37c2dd 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -38,6 +38,8 @@ def __init__(self, cc_issuer_ids: [str], cc_verifier_ids: [str]): preparing the token to the server keeping clients' tokens in server validating all tokens in the entire NVFlare system + not allowing the system to start if failed to get CC token + shutdown the running jobs if CC tokens expired Args: @@ -155,8 +157,7 @@ def _verify_running_jobs(self, fl_ctx): for _, client in job_participants.items(): participants.append(client.name) - participant_tokens = {} - err = self._verify_participants(participants, participant_tokens) + err, participant_tokens = self._verify_participants(participants) if err: engine.job_runner.stop_run(job_id, fl_ctx) self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") @@ -214,8 +215,7 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if not isinstance(participants, list): return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list bot got {type(participants)}" - participant_tokens = {} - err = self._verify_participants(participants, participant_tokens) + err, participant_tokens = self._verify_participants(participants) if err: return err @@ -223,10 +223,11 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") return "" - def _verify_participants(self, participants, participant_tokens): + def _verify_participants(self, participants): # if server token expired, then generates a new one self._handle_expired_tokens() + participant_tokens = {} site_cc_info = self.participant_cc_info[self.site_name] participant_tokens[self.site_name] = self._get_participant_tokens(site_cc_info) @@ -240,7 +241,7 @@ def _verify_participants(self, participants, participant_tokens): participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p]) else: participant_tokens[p] = [{}] - return self._validate_participants_tokens(participant_tokens) + return self._validate_participants_tokens(participant_tokens), participant_tokens def _get_participant_tokens(self, site_cc_info): cc_info = [] From 51226d69d4428d8b4965940afb3f8b376cf829e0 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 4 Mar 2024 16:02:55 -0500 Subject: [PATCH 24/44] set up the proper fl_ctx for admin send_requests(). --- nvflare/apis/server_engine_spec.py | 3 +- .../job_schedulers/job_scheduler.py | 5 +++- .../confidential_computing/cc_manager.py | 4 +-- nvflare/private/fed/server/admin.py | 30 +++++++++++-------- nvflare/private/fed/server/cmd_utils.py | 3 +- nvflare/private/fed/server/job_runner.py | 3 +- nvflare/private/fed/server/server_engine.py | 13 ++++---- .../job_schedulers/job_scheduler_test.py | 16 +++++----- 8 files changed, 44 insertions(+), 33 deletions(-) diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index 1b1e8c14c4..f13867b07b 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -187,7 +187,7 @@ def check_client_resources( @abstractmethod def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): """Cancels the request resources for the job. @@ -195,6 +195,7 @@ def cancel_client_resources( resource_check_results: A dict of {client_name: client_check_result} where client_check_result is a tuple of (is_resource_enough, resource reserve token if any) resource_reqs: A dict of {client_name: resource requirements dict} + fl_ctx: FLContext """ pass diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index 2619b98d68..764a8e372f 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -93,7 +93,7 @@ def _cancel_resources( if not isinstance(engine, ServerEngineSpec): raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.") - engine.cancel_client_resources(resource_check_results, resource_reqs) + engine.cancel_client_resources(resource_check_results, resource_reqs, fl_ctx) self.log_debug(fl_ctx, f"cancel client resources using check results: {resource_check_results}") return False, None @@ -164,6 +164,9 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp self.log_info(fl_ctx, f"Job {job.job_id} can't be scheduled: {block_reason}") return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason + PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" + cc_peer_ctx = fl_ctx.get_prop(key=PEER_CTX_CC_TOKEN) + self.logger.info(f"++++++++++ {cc_peer_ctx}") resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx) self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index d84d37c2dd..8a1fc4d584 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -219,7 +219,7 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str: if err: return err - fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False) + fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=False, private=False) self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}") return "" @@ -294,7 +294,7 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) def _not_authorize_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_REASON, value=reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) def _block_job(self, reason: str, fl_ctx: FLContext): diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index 7581502d81..a8b40240c7 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -18,6 +18,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import ServerCommandKey +from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReservedHeaderKey from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.f3.cellnet.cell import Cell @@ -231,11 +232,12 @@ def send_request_to_client(self, req: Message, client_token: str, timeout_secs=2 if not isinstance(req, Message): raise TypeError("request must be Message but got {}".format(type(req))) reqs = {client_token: req} - replies = self.send_requests(reqs, timeout_secs=timeout_secs) - if replies is None or len(replies) <= 0: - return None - else: - return replies[0] + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(reqs, fl_ctx, timeout_secs=timeout_secs) + if replies is None or len(replies) <= 0: + return None + else: + return replies[0] def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> dict: """Send requests to clients @@ -252,12 +254,13 @@ def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> for token, _ in requests.items(): result[token] = None - replies = self.send_requests(requests, timeout_secs=timeout_secs) - for r in replies: - result[r.client_token] = r.reply + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) + for r in replies: + result[r.client_token] = r.reply return result - def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [ClientReply]: + def send_requests(self, requests: dict, fl_ctx: FLContext, timeout_secs=2.0, optional=False) -> [ClientReply]: """Send requests to clients. NOTE:: @@ -268,6 +271,7 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl Args: requests: A dict of requests: {client token: request or list of requests} + fl_ctx: FLContext timeout_secs: how long to wait for reply before timeout optional: whether the requests are optional @@ -276,10 +280,10 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl """ for _, request in requests.items(): - with self.sai.new_context() as fl_ctx: - self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) - shared_fl_ctx = gen_new_peer_ctx(fl_ctx) - request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + # with self.sai.new_context() as fl_ctx: + self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) return send_requests( cell=self.cell, diff --git a/nvflare/private/fed/server/cmd_utils.py b/nvflare/private/fed/server/cmd_utils.py index d2d7bd6bfd..4cc8cc72ce 100644 --- a/nvflare/private/fed/server/cmd_utils.py +++ b/nvflare/private/fed/server/cmd_utils.py @@ -148,7 +148,8 @@ def send_request_to_clients(self, conn, message): cmd_timeout = conn.get_prop(ConnProps.CMD_TIMEOUT) if not cmd_timeout: cmd_timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=cmd_timeout) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=cmd_timeout) return replies diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index ae2cb7d568..187cb2f359 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -49,7 +49,8 @@ def _send_to_clients(admin_server, client_sites: List[str], engine, message, tim if timeout is None: timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=timeout, optional=optional) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout, optional=optional) return replies diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index d6b636b575..f36beb61a0 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -721,8 +721,8 @@ def reset_errors(self, job_id) -> str: return f"reset the server error stats for job: {job_id}" - def _send_admin_requests(self, requests, timeout_secs=10) -> List[ClientReply]: - return self.server.admin_server.send_requests(requests, timeout_secs=timeout_secs) + def _send_admin_requests(self, requests, fl_ctx: FLContext, timeout_secs=10) -> List[ClientReply]: + return self.server.admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> Dict[str, Tuple[bool, str]]: requests = {} @@ -737,7 +737,7 @@ def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, 15) + replies = self._send_admin_requests(requests, fl_ctx, 15) result = {} for r in replies: site_name = r.client_name @@ -765,7 +765,7 @@ def _make_message_for_check_resource(self, job, resource_requirements, fl_ctx): return request def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): requests = {} for site_name, result in resource_check_results.items(): @@ -778,7 +778,7 @@ def cancel_client_resources( if client: requests.update({client.token: request}) if requests: - _ = self._send_admin_requests(requests) + _ = self._send_admin_requests(requests, fl_ctx) def start_client_job(self, job_id, client_sites): requests = {} @@ -793,7 +793,8 @@ def start_client_job(self, job_id, client_sites): requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, timeout_secs=20) + with self.new_context() as fl_ctx: + replies = self._send_admin_requests(requests, fl_ctx, timeout_secs=20) return replies def stop_all_jobs(self): diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py index df6c63a912..a16a3fb56e 100644 --- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py +++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py @@ -136,15 +136,15 @@ def get_client_name_from_token(self, token): return self.clients.get(token) def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): - with self.new_context() as fl_ctx: - for site_name, result in resource_check_results.items(): - check_result, token = result - if check_result and token: - self.clients[site_name].resource_manager.cancel_resources( - resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx - ) + # with self.new_context() as fl_ctx: + for site_name, result in resource_check_results.items(): + check_result, token = result + if check_result and token: + self.clients[site_name].resource_manager.cancel_resources( + resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx + ) def update_job_run_status(self): pass From ed770ac2bf8483b27cbf9bb099efe0e091cad208 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 4 Mar 2024 16:17:24 -0500 Subject: [PATCH 25/44] Add proper fl_ctx. --- nvflare/apis/server_engine_spec.py | 3 ++- nvflare/private/fed/server/job_runner.py | 2 +- nvflare/private/fed/server/server_engine.py | 5 ++--- .../app_common/job_schedulers/job_scheduler_test.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index f13867b07b..004e623056 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -154,12 +154,13 @@ def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext): pass @abstractmethod - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): """To send the start client run commands to the clients Args: client_sites: client sites job_id: job_id + fl_ctx: FLContext Returns: diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index 187cb2f359..15a9bea833 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -249,7 +249,7 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") - replies = engine.start_client_job(job_id, client_sites) + replies = engine.start_client_job(job_id, client_sites, fl_ctx) client_sites_names = list(client_sites.keys()) check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})") display_sites = ",".join(client_sites_names) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index f36beb61a0..ed52b4e95f 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -780,7 +780,7 @@ def cancel_client_resources( if requests: _ = self._send_admin_requests(requests, fl_ctx) - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): requests = {} for site, dispatch_info in client_sites.items(): resource_requirement = dispatch_info.resource_requirements @@ -793,8 +793,7 @@ def start_client_job(self, job_id, client_sites): requests.update({client.token: request}) replies = [] if requests: - with self.new_context() as fl_ctx: - replies = self._send_admin_requests(requests, fl_ctx, timeout_secs=20) + replies = self._send_admin_requests(requests, fl_ctx, timeout_secs=20) return replies def stop_all_jobs(self): diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py index a16a3fb56e..1f7f9ab48e 100644 --- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py +++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py @@ -120,7 +120,7 @@ def persist_components(self, fl_ctx: FLContext, completed: bool): def restore_components(self, snapshot, fl_ctx: FLContext): pass - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): pass def check_client_resources( From f76fa1e544833a6d6cb4bb5f99bc869ae1097cec Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 5 Mar 2024 09:33:18 -0500 Subject: [PATCH 26/44] Refactor the CCManager. --- .../confidential_computing/cc_manager.py | 18 ++++-------------- .../{tdx_connector.py => tdx_authorizer.py} | 2 +- 2 files changed, 5 insertions(+), 15 deletions(-) rename nvflare/app_opt/confidential_computing/{tdx_connector.py => tdx_authorizer.py} (98%) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 8a1fc4d584..eb2b1e01b1 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -85,7 +85,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): err = "Participants unable to meet client CC requirements" finally: if err: - self._not_authorize_job(err, fl_ctx) + self._block_job(err, fl_ctx) elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: # Server side: job scheduler check client resources try: @@ -96,21 +96,18 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): finally: if err: self._block_job(err, fl_ctx) - elif event_type == EventType.AFTER_CHECK_CLIENT_RESOURCES: - # Server side - fl_ctx.remove_prop(PEER_CTX_CC_TOKEN) def _setup_cc_authorizers(self, fl_ctx): engine = fl_ctx.get_engine() for i_id in self.cc_issuer_ids: issuer = engine.get_component(i_id) - if not isinstance(issuer, CCAuthorizer): + if not (isinstance(issuer, CCAuthorizer) and issuer.can_generate()): raise RuntimeError(f"cc_issuer_id {i_id} must be a CCAuthorizer, but got {issuer.__class__}") self.cc_issuers.append(issuer) for v_id in self.cc_verifier_ids: authorizer = engine.get_component(v_id) - if not isinstance(authorizer, CCAuthorizer): + if not (isinstance(authorizer, CCAuthorizer) and authorizer.can_verify()): raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {authorizer.__class__}") namespace = authorizer.get_namespace() if namespace in self.cc_verifiers.keys(): @@ -277,8 +274,7 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) result = {} invalid_participant_list = [] if not participants: - return result - + return result, invalid_participant_list for k, cc_info in participants.items(): for v in cc_info: token = v.get(CC_TOKEN, "") @@ -291,12 +287,6 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) self.logger.info(f"CC - results from validating participants' tokens: {result}") return result, invalid_participant_list - def _not_authorize_job(self, reason: str, fl_ctx: FLContext): - job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") - self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) - def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") diff --git a/nvflare/app_opt/confidential_computing/tdx_connector.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py similarity index 98% rename from nvflare/app_opt/confidential_computing/tdx_connector.py rename to nvflare/app_opt/confidential_computing/tdx_authorizer.py index 2cb657396c..33f318b2de 100644 --- a/nvflare/app_opt/confidential_computing/tdx_connector.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -24,7 +24,7 @@ ERROR_FILE = "error.txt" -class TDXConnector(CCAuthorizer): +class TDXAuthorizer(CCAuthorizer): def __init__(self, tdx_cli_command: str, config_dir: str) -> None: super().__init__() self.tdx_cli_command = tdx_cli_command From 74a6059bcbc929526d0c5866bcd11592496e91f2 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 5 Mar 2024 14:55:58 -0500 Subject: [PATCH 27/44] Refactor the CCManager and TDX_authorizer. --- nvflare/app_opt/confidential_computing/cc_manager.py | 2 -- nvflare/app_opt/confidential_computing/tdx_authorizer.py | 4 ---- 2 files changed, 6 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index eb2b1e01b1..96dfd7621c 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -255,8 +255,6 @@ def _handle_expired_tokens(self): token = i.get(CC_TOKEN) if not issuer.verify(token): token = issuer.generate() - if not token: - raise RuntimeError(f"{self.site_name} failed to generate a new CC token") i[CC_TOKEN] = token self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py index 33f318b2de..6077febd2a 100644 --- a/nvflare/app_opt/confidential_computing/tdx_authorizer.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -41,8 +41,6 @@ def generate(self) -> str: command = ['sudo', self.tdx_cli_command, '-c', self.config_file, 'token', "--no-eventlog"] subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) - # with open(token_file, "r") as f: - # token = f.readline() with open(error_file, "r") as f: if 'Error:' in f.read(): return "" @@ -59,8 +57,6 @@ def verify(self, token: str) -> bool: command = [self.tdx_cli_command, "verify", "--config", self.config_file, "--token", token] subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) - # with open(VERIFY_FILE, "r") as f: - # result = f.readline() with open(error_file, "r") as f: if 'Error:' in f.read(): return False From 313ed21d752c345f2f6943207ca8fb1232aa17c4 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 6 Mar 2024 13:47:44 -0500 Subject: [PATCH 28/44] Added TOKEN_EXPIRATION for each cc_issue in CCManager. --- .../confidential_computing/cc_manager.py | 83 ++++++++++++------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 96dfd7621c..57255c1415 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading +import time from typing import Dict, List from nvflare.apis.event_type import EventType @@ -28,9 +29,13 @@ CC_INFO = "_cc_info" CC_TOKEN_VALIDATED = "_cc_token_validated" +CC_ISSUER_ID = "issuer_id" +TOKEN_GENERATION_TIME = "token_generation_time" +TOKEN_EXPIRATION = "token_expiration" + class CCManager(FLComponent): - def __init__(self, cc_issuer_ids: [str], cc_verifier_ids: [str]): + def __init__(self, cc_issuers_conf: [str], cc_verifier_ids: [str], verify_frequency=600): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -46,9 +51,11 @@ def __init__(self, cc_issuer_ids: [str], cc_verifier_ids: [str]): """ FLComponent.__init__(self) self.site_name = None - self.cc_issuer_ids = cc_issuer_ids + self.cc_issuers_conf = cc_issuers_conf self.cc_verifier_ids = cc_verifier_ids - self.cc_issuers = [] + self.verify_frequency = verify_frequency + self.verify_time = None + self.cc_issuers = {} self.cc_verifiers = {} self.participant_cc_info = {} # used by the Server to keep tokens of all clients @@ -59,7 +66,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): try: self._setup_cc_authorizers(fl_ctx) - err = self._prepare_for_attestation(fl_ctx) + err = self._generate_tokens(fl_ctx) except: self.log_exception(fl_ctx, "exception in attestation preparation") err = "exception in attestation preparation" @@ -69,7 +76,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): raise UnsafeComponentError(err) elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side - self._prepare_token_for_login(fl_ctx) + self._prepare_cc_info(fl_ctx) elif event_type == EventType.CLIENT_REGISTERED or event_type == EventType.AFTER_CLIENT_HEARTBEAT: # Server side self._add_client_token(fl_ctx) @@ -99,22 +106,24 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def _setup_cc_authorizers(self, fl_ctx): engine = fl_ctx.get_engine() - for i_id in self.cc_issuer_ids: - issuer = engine.get_component(i_id) + for conf in self.cc_issuers_conf: + issuer_id = conf.get(CC_ISSUER_ID) + expiration = conf.get(TOKEN_EXPIRATION) + issuer = engine.get_component(issuer_id) if not (isinstance(issuer, CCAuthorizer) and issuer.can_generate()): - raise RuntimeError(f"cc_issuer_id {i_id} must be a CCAuthorizer, but got {issuer.__class__}") - self.cc_issuers.append(issuer) + raise RuntimeError(f"cc_issuer_id {issuer_id} must be a CCAuthorizer, but got {issuer.__class__}") + self.cc_issuers[issuer] = expiration for v_id in self.cc_verifier_ids: - authorizer = engine.get_component(v_id) - if not (isinstance(authorizer, CCAuthorizer) and authorizer.can_verify()): - raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {authorizer.__class__}") - namespace = authorizer.get_namespace() + verifier = engine.get_component(v_id) + if not (isinstance(verifier, CCAuthorizer) and verifier.can_verify()): + raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {verifier.__class__}") + namespace = verifier.get_namespace() if namespace in self.cc_verifiers.keys(): raise RuntimeError(f"Authorizer with namespace: {namespace} already exist.") - self.cc_verifiers[namespace] = authorizer + self.cc_verifiers[namespace] = verifier - def _prepare_token_for_login(self, fl_ctx: FLContext): + def _prepare_cc_info(self, fl_ctx: FLContext): # client side: if token expired then generate a new one self._handle_expired_tokens() @@ -141,23 +150,30 @@ def _add_client_token(self, fl_ctx: FLContext): self.participant_cc_info[token_owner] = peer_cc_info self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") - with self.lock: + self._verify_running_jobs(fl_ctx) + else: + if time.time() - self.verify_time > self.verify_frequency: self._verify_running_jobs(fl_ctx) def _verify_running_jobs(self, fl_ctx): engine = fl_ctx.get_engine() run_processes = engine.run_processes running_jobs = list(run_processes.keys()) - for job_id in running_jobs: - job_participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) - participants = [] - for _, client in job_participants.items(): - participants.append(client.name) + with self.lock: + for job_id in running_jobs: + job_participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS) + participants = [] + for _, client in job_participants.items(): + participants.append(client.name) + + err, participant_tokens = self._verify_participants(participants) + if err: + # maybe shutdown the whole system here. leave the user to define the action + engine.job_runner.stop_run(job_id, fl_ctx) - err, participant_tokens = self._verify_participants(participants) - if err: - engine.job_runner.stop_run(job_id, fl_ctx) - self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") + self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") + + self.verify_time = time.time() def _remove_client_token(self, fl_ctx: FLContext): # server side @@ -166,13 +182,13 @@ def _remove_client_token(self, fl_ctx: FLContext): self.participant_cc_info.pop(token_owner) self.logger.info(f"Removed CC client: {token_owner}") - def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: + def _generate_tokens(self, fl_ctx: FLContext) -> str: # both server and client sides self.site_name = fl_ctx.get_identity_name() workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() self.participant_cc_info[self.site_name] = [] - for issuer in self.cc_issuers: + for issuer, expiration in self.cc_issuers: my_token = issuer.generate() namespace = issuer.get_namespace() @@ -180,7 +196,12 @@ def _prepare_for_attestation(self, fl_ctx: FLContext) -> str: return "failed to get CC token" self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") - cc_info = {CC_TOKEN: my_token, CC_ISSUER: issuer, CC_NAMESPACE: namespace, CC_TOKEN_VALIDATED: True} + cc_info = {CC_TOKEN: my_token, + CC_ISSUER: issuer, + CC_NAMESPACE: namespace, + TOKEN_GENERATION_TIME: time.time(), + TOKEN_EXPIRATION: expiration, + CC_TOKEN_VALIDATED: True} self.participant_cc_info[self.site_name].append(cc_info) return "" @@ -252,10 +273,12 @@ def _handle_expired_tokens(self): site_cc_info = self.participant_cc_info[self.site_name] for i in site_cc_info: issuer = i.get(CC_ISSUER) - token = i.get(CC_TOKEN) - if not issuer.verify(token): + token_generate_time = i.get(TOKEN_GENERATION_TIME) + expiration = i.get(TOKEN_EXPIRATION) + if time.time() - token_generate_time > expiration: token = issuer.generate() i[CC_TOKEN] = token + i[TOKEN_GENERATION_TIME] = time.time() self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") def _validate_participants_tokens(self, participants) -> str: From 40d70c667d9e4579c738ac4081cfbba9788bd667 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 6 Mar 2024 14:21:15 -0500 Subject: [PATCH 29/44] Fixed CC TOKEN_EXPIRATION error. --- nvflare/app_opt/confidential_computing/cc_manager.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 57255c1415..8e39353fb7 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -35,7 +35,7 @@ class CCManager(FLComponent): - def __init__(self, cc_issuers_conf: [str], cc_verifier_ids: [str], verify_frequency=600): + def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], verify_frequency=600): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -188,10 +188,12 @@ def _generate_tokens(self, fl_ctx: FLContext) -> str: workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir() self.participant_cc_info[self.site_name] = [] - for issuer, expiration in self.cc_issuers: + for issuer, expiration in self.cc_issuers.items(): my_token = issuer.generate() namespace = issuer.get_namespace() + if not isinstance(expiration, int): + raise ValueError(f"token_expiration value must be int, but got {expiration.__class__}") if not my_token: return "failed to get CC token" @@ -200,7 +202,7 @@ def _generate_tokens(self, fl_ctx: FLContext) -> str: CC_ISSUER: issuer, CC_NAMESPACE: namespace, TOKEN_GENERATION_TIME: time.time(), - TOKEN_EXPIRATION: expiration, + TOKEN_EXPIRATION: int(expiration), CC_TOKEN_VALIDATED: True} self.participant_cc_info[self.site_name].append(cc_info) From 0c3c188712e40b531e2db0e1e80be808691f82f2 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 6 Mar 2024 15:53:45 -0500 Subject: [PATCH 30/44] refactor the CCManager _prepare_cc_info() --- .../confidential_computing/cc_manager.py | 43 +++++++++++-------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 8e39353fb7..47289d8faf 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -59,6 +59,7 @@ def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], ve self.cc_verifiers = {} self.participant_cc_info = {} # used by the Server to keep tokens of all clients + self.token_submitted = False self.lock = threading.Lock() def handle_event(self, event_type: str, fl_ctx: FLContext): @@ -127,33 +128,34 @@ def _prepare_cc_info(self, fl_ctx: FLContext): # client side: if token expired then generate a new one self._handle_expired_tokens() - site_cc_info = self.participant_cc_info[self.site_name] - cc_info = self._get_participant_tokens(site_cc_info) - fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) + if not self.token_submitted: + site_cc_info = self.participant_cc_info[self.site_name] + cc_info = self._get_participant_tokens(site_cc_info) + fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) + self.token_submitted = True def _add_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - peer_cc_info = peer_ctx.get_prop(CC_INFO, [{CC_TOKEN: "", CC_NAMESPACE: ""}]) - new_tokens = [] - for i in peer_cc_info: - new_tokens.append(i[CC_TOKEN]) - - old_cc_info = self.participant_cc_info.get(token_owner) - old_tokens = [] - if old_cc_info: - for i in old_cc_info: - old_tokens.append(i[CC_TOKEN]) - - if not old_cc_info or set(new_tokens) != set(old_tokens): + peer_cc_info = peer_ctx.get_prop(CC_INFO) + # new_tokens = [] + # for i in peer_cc_info: + # new_tokens.append(i[CC_TOKEN]) + # + # old_cc_info = self.participant_cc_info.get(token_owner) + # old_tokens = [] + # if old_cc_info: + # for i in old_cc_info: + # old_tokens.append(i[CC_TOKEN]) + + # if not old_cc_info or set(new_tokens) != set(old_tokens): + if peer_cc_info: self.participant_cc_info[token_owner] = peer_cc_info self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") + if time.time() - self.verify_time > self.verify_frequency: self._verify_running_jobs(fl_ctx) - else: - if time.time() - self.verify_time > self.verify_frequency: - self._verify_running_jobs(fl_ctx) def _verify_running_jobs(self, fl_ctx): engine = fl_ctx.get_engine() @@ -205,6 +207,7 @@ def _generate_tokens(self, fl_ctx: FLContext) -> str: TOKEN_EXPIRATION: int(expiration), CC_TOKEN_VALIDATED: True} self.participant_cc_info[self.site_name].append(cc_info) + self.token_submitted = False return "" @@ -260,7 +263,7 @@ def _verify_participants(self, participants): if self.participant_cc_info.get(p): participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p]) else: - participant_tokens[p] = [{}] + participant_tokens[p] = [{CC_TOKEN: "", CC_NAMESPACE: ""}] return self._validate_participants_tokens(participant_tokens), participant_tokens def _get_participant_tokens(self, site_cc_info): @@ -283,6 +286,8 @@ def _handle_expired_tokens(self): i[TOKEN_GENERATION_TIME] = time.time() self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") + self.token_submitted = False + def _validate_participants_tokens(self, participants) -> str: self.logger.debug(f"Validating participant tokens {participants=}") result, invalid_participant_list = self._validate_participants(participants) From affda4b6f5e4d1aa035e600efc0a0d035c0f5f4e Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 6 Mar 2024 16:24:53 -0500 Subject: [PATCH 31/44] Refactor. --- nvflare/app_opt/confidential_computing/cc_manager.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 47289d8faf..cbd9986a75 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -132,6 +132,7 @@ def _prepare_cc_info(self, fl_ctx: FLContext): site_cc_info = self.participant_cc_info[self.site_name] cc_info = self._get_participant_tokens(site_cc_info) fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False) + self.logger.info("Sent the CC-tokens to server.") self.token_submitted = True def _add_client_token(self, fl_ctx: FLContext): @@ -154,7 +155,7 @@ def _add_client_token(self, fl_ctx: FLContext): self.participant_cc_info[token_owner] = peer_cc_info self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") - if time.time() - self.verify_time > self.verify_frequency: + if not self.verify_time or time.time() - self.verify_time > self.verify_frequency: self._verify_running_jobs(fl_ctx) def _verify_running_jobs(self, fl_ctx): @@ -258,8 +259,8 @@ def _verify_participants(self, participants): assert isinstance(p, str) if p == self.site_name: continue - if p not in self.participant_cc_info: - return f"no token available for participant {p}" + # if p not in self.participant_cc_info: + # return f"no token available for participant {p}" if self.participant_cc_info.get(p): participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p]) else: @@ -286,7 +287,7 @@ def _handle_expired_tokens(self): i[TOKEN_GENERATION_TIME] = time.time() self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") - self.token_submitted = False + self.token_submitted = False def _validate_participants_tokens(self, participants) -> str: self.logger.debug(f"Validating participant tokens {participants=}") From 2dc7df349100a5b8c8fbf26696ea2091b352f97d Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 7 Mar 2024 09:40:30 -0500 Subject: [PATCH 32/44] refactor the cc tokens periodic verification. --- .../confidential_computing/cc_manager.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index cbd9986a75..765996ac43 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -140,17 +140,7 @@ def _add_client_token(self, fl_ctx: FLContext): peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() peer_cc_info = peer_ctx.get_prop(CC_INFO) - # new_tokens = [] - # for i in peer_cc_info: - # new_tokens.append(i[CC_TOKEN]) - # - # old_cc_info = self.participant_cc_info.get(token_owner) - # old_tokens = [] - # if old_cc_info: - # for i in old_cc_info: - # old_tokens.append(i[CC_TOKEN]) - - # if not old_cc_info or set(new_tokens) != set(old_tokens): + if peer_cc_info: self.participant_cc_info[token_owner] = peer_cc_info self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}") @@ -182,8 +172,9 @@ def _remove_client_token(self, fl_ctx: FLContext): # server side peer_ctx = fl_ctx.get_peer_context() token_owner = peer_ctx.get_identity_name() - self.participant_cc_info.pop(token_owner) - self.logger.info(f"Removed CC client: {token_owner}") + if token_owner in self.participant_cc_info.keys(): + self.participant_cc_info.pop(token_owner) + self.logger.info(f"Removed CC client: {token_owner}") def _generate_tokens(self, fl_ctx: FLContext) -> str: # both server and client sides From 6934884aec678f5cecfa44412b6ba6c5ef846031 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 7 Mar 2024 15:11:20 -0500 Subject: [PATCH 33/44] added critical_level for CCManager. --- nvflare/apis/fl_constant.py | 1 + .../job_schedulers/job_scheduler.py | 1 + .../confidential_computing/cc_manager.py | 62 ++++++++++++++++--- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 4f1f32b0ce..2065f715c3 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -150,6 +150,7 @@ class FLContextKey(object): COMMUNICATION_ERROR = "Flare_communication_error__" UNAUTHENTICATED = "Flare_unauthenticated__" CLIENT_RESOURCE_SPECS = "__client_resource_specs" + CLIENT_RESOURCE_RESULT = "__client_resource_result" JOB_PARTICIPANTS = "__job_participants" JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling SSID = "__ssid__" diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index 764a8e372f..adb40d3558 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -168,6 +168,7 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp cc_peer_ctx = fl_ctx.get_prop(key=PEER_CTX_CC_TOKEN) self.logger.info(f"++++++++++ {cc_peer_ctx}") resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx) + fl_ctx.set_prop(FLContextKey.CLIENT_RESOURCE_RESULT, resource_check_results, private=True, sticky=False) self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx) if not resource_check_results: diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 765996ac43..5d3f93e3ad 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -21,6 +21,8 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer +from nvflare.fuel.hci.conn import Connection +from nvflare.private.fed.server.training_cmds import TrainingCommandModule PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" CC_TOKEN = "_cc_token" @@ -28,14 +30,19 @@ CC_NAMESPACE = "_cc_namespace" CC_INFO = "_cc_info" CC_TOKEN_VALIDATED = "_cc_token_validated" +CC_VERIFY_ERROR = "_cc_verify_error." CC_ISSUER_ID = "issuer_id" TOKEN_GENERATION_TIME = "token_generation_time" TOKEN_EXPIRATION = "token_expiration" +SHUTDOWN_SYSTEM = 1 +SHUTDOWN_JOB = 2 + class CCManager(FLComponent): - def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], verify_frequency=600): + def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], + verify_frequency=600, critical_level=SHUTDOWN_JOB): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -53,7 +60,15 @@ def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], ve self.site_name = None self.cc_issuers_conf = cc_issuers_conf self.cc_verifier_ids = cc_verifier_ids - self.verify_frequency = verify_frequency + + if not isinstance(verify_frequency, int): + raise ValueError(f"verify_frequency must be in, but got {verify_frequency.__class__}") + self.verify_frequency = int(verify_frequency) + + self.critical_level = critical_level + if self.critical_level not in [SHUTDOWN_SYSTEM, SHUTDOWN_JOB]: + raise ValueError(f"critical_level must be in [{SHUTDOWN_SYSTEM}, {SHUTDOWN_JOB}]. But got {critical_level}") + self.verify_time = None self.cc_issuers = {} self.cc_verifiers = {} @@ -103,7 +118,19 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): err = "Clients unable to meet server CC requirements" finally: if err: - self._block_job(err, fl_ctx) + if self.critical_level == SHUTDOWN_JOB: + self._block_job(err, fl_ctx) + else: + threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() + elif event_type == EventType.AFTER_CHECK_CLIENT_RESOURCES: + client_resource_result = fl_ctx.get_prop(FLContextKey.CLIENT_RESOURCE_RESULT) + if client_resource_result: + for site_name, check_result in client_resource_result.items(): + is_resource_enough, reason = check_result + if not is_resource_enough and reason.startswith(CC_VERIFY_ERROR) \ + and self.critical_level == SHUTDOWN_SYSTEM: + threading.Thread(target=self._shutdown_system, args=[reason, fl_ctx]).start() + break def _setup_cc_authorizers(self, fl_ctx): engine = fl_ctx.get_engine() @@ -161,10 +188,12 @@ def _verify_running_jobs(self, fl_ctx): err, participant_tokens = self._verify_participants(participants) if err: - # maybe shutdown the whole system here. leave the user to define the action - engine.job_runner.stop_run(job_id, fl_ctx) - - self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") + if self.critical_level == SHUTDOWN_JOB: + # maybe shutdown the whole system here. leave the user to define the action + engine.job_runner.stop_run(job_id, fl_ctx) + self.logger.info(f"Stop Job: {job_id} with CC verification error: {err} ") + else: + threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() self.verify_time = time.time() @@ -310,5 +339,22 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=CC_VERIFY_ERROR+reason, sticky=False) fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) + + def _shutdown_system(self, reason: str, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + run_processes = engine.run_processes + running_jobs = list(run_processes.keys()) + for job_id in running_jobs: + engine.job_runner.stop_run(job_id, fl_ctx) + + conn = Connection({}, engine.server.admin_server) + conn.app_ctx = engine + + cmd = TrainingCommandModule() + args = ["shutdown", "all"] + cmd.validate_command_targets(conn, args[1:]) + cmd.shutdown(conn, args) + + self.logger.error(f"CC system shutdown! due to reason: {reason}") From 199eb1e52cd5b67b4632eaa4e347045c7fc7b3f7 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 8 Mar 2024 16:07:12 -0500 Subject: [PATCH 34/44] codestyle fix. --- .../confidential_computing/cc_authorizer.py | 1 + .../confidential_computing/cc_manager.py | 36 ++++++++++++------- .../confidential_computing/tdx_authorizer.py | 6 ++-- .../fed/app/deployer/server_deployer.py | 2 +- nvflare/private/fed/client/communicator.py | 4 +-- nvflare/private/fed/client/scheduler_cmds.py | 4 +-- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py index fb1fc525bb..c09a9513a0 100644 --- a/nvflare/app_opt/confidential_computing/cc_authorizer.py +++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py @@ -13,6 +13,7 @@ # limitations under the License. # import os.path + class CCAuthorizer: def can_generate(self) -> bool: """This indicates if the authorizer can generate a CC token or not. diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 5d3f93e3ad..a2b6642722 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -41,8 +41,13 @@ class CCManager(FLComponent): - def __init__(self, cc_issuers_conf: [Dict[str, str]], cc_verifier_ids: [str], - verify_frequency=600, critical_level=SHUTDOWN_JOB): + def __init__( + self, + cc_issuers_conf: [Dict[str, str]], + cc_verifier_ids: [str], + verify_frequency=600, + critical_level=SHUTDOWN_JOB, + ): """Manage all confidential computing related tasks. This manager does the following tasks: @@ -127,8 +132,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if client_resource_result: for site_name, check_result in client_resource_result.items(): is_resource_enough, reason = check_result - if not is_resource_enough and reason.startswith(CC_VERIFY_ERROR) \ - and self.critical_level == SHUTDOWN_SYSTEM: + if ( + not is_resource_enough + and reason.startswith(CC_VERIFY_ERROR) + and self.critical_level == SHUTDOWN_SYSTEM + ): threading.Thread(target=self._shutdown_system, args=[reason, fl_ctx]).start() break @@ -221,12 +229,14 @@ def _generate_tokens(self, fl_ctx: FLContext) -> str: return "failed to get CC token" self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") - cc_info = {CC_TOKEN: my_token, - CC_ISSUER: issuer, - CC_NAMESPACE: namespace, - TOKEN_GENERATION_TIME: time.time(), - TOKEN_EXPIRATION: int(expiration), - CC_TOKEN_VALIDATED: True} + cc_info = { + CC_TOKEN: my_token, + CC_ISSUER: issuer, + CC_NAMESPACE: namespace, + TOKEN_GENERATION_TIME: time.time(), + TOKEN_EXPIRATION: int(expiration), + CC_TOKEN_VALIDATED: True, + } self.participant_cc_info[self.site_name].append(cc_info) self.token_submitted = False @@ -305,7 +315,9 @@ def _handle_expired_tokens(self): token = issuer.generate() i[CC_TOKEN] = token i[TOKEN_GENERATION_TIME] = time.time() - self.logger.info(f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}") + self.logger.info( + f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}" + ) self.token_submitted = False @@ -339,7 +351,7 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=CC_VERIFY_ERROR+reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=CC_VERIFY_ERROR + reason, sticky=False) fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) def _shutdown_system(self, reason: str, fl_ctx: FLContext): diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py index 6077febd2a..3825588dd8 100644 --- a/nvflare/app_opt/confidential_computing/tdx_authorizer.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -38,11 +38,11 @@ def generate(self) -> str: error_file = os.path.join(self.config_dir, ERROR_FILE) err_out = open(error_file, "w") - command = ['sudo', self.tdx_cli_command, '-c', self.config_file, 'token', "--no-eventlog"] + command = ["sudo", self.tdx_cli_command, "-c", self.config_file, "token", "--no-eventlog"] subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) with open(error_file, "r") as f: - if 'Error:' in f.read(): + if "Error:" in f.read(): return "" else: with open(token_file, "r") as f: @@ -58,7 +58,7 @@ def verify(self, token: str) -> bool: subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) with open(error_file, "r") as f: - if 'Error:' in f.read(): + if "Error:" in f.read(): return False return True diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index a2da643518..40924392b8 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -16,7 +16,7 @@ import threading from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import SystemComponents, FLContextKey +from nvflare.apis.fl_constant import FLContextKey, SystemComponents from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.workspace import Workspace from nvflare.private.fed.server.fed_server import FederatedServer diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index f4d3253a16..1a94384d5d 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -304,7 +304,7 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, }, - shareable + shareable, ) try: result = self.cell.send_request( @@ -347,7 +347,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, }, - shareable + shareable, ) try: diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 8ec3e0d855..f3528082df 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -16,9 +16,9 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents, ServerCommandKey +from nvflare.apis.fl_constant import FLContextKey, ReturnCode, ServerCommandKey, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec -from nvflare.apis.shareable import Shareable, ReservedHeaderKey +from nvflare.apis.shareable import ReservedHeaderKey, Shareable from nvflare.private.admin_defs import Message from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader, SysCommandTopic, TrainingTopic from nvflare.private.fed.client.admin import RequestProcessor From 093673329bd7380ee5cd4c95ea049908b3d488df Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 8 Mar 2024 16:08:50 -0500 Subject: [PATCH 35/44] removed no used import. --- nvflare/private/fed/client/scheduler_cmds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index f3528082df..7ea2e8d55b 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -18,7 +18,7 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, ReturnCode, ServerCommandKey, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec -from nvflare.apis.shareable import ReservedHeaderKey, Shareable +from nvflare.apis.shareable import Shareable from nvflare.private.admin_defs import Message from nvflare.private.defs import ERROR_MSG_PREFIX, RequestHeader, SysCommandTopic, TrainingTopic from nvflare.private.fed.client.admin import RequestProcessor From 0b28480eea0537af99942eae53cfd39b1f233584 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 8 Mar 2024 16:14:36 -0500 Subject: [PATCH 36/44] removed no use import. --- nvflare/private/fed/server/admin.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index a8b40240c7..71fc765939 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy import threading import time from typing import List, Optional @@ -19,7 +18,6 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import ServerCommandKey from nvflare.apis.fl_context import FLContext -from nvflare.apis.shareable import ReservedHeaderKey from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.net_agent import NetAgent From 000735c5ae21623cdaba0209d84a62caa9101c55 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 8 Mar 2024 18:00:12 -0500 Subject: [PATCH 37/44] Fixed the unitest. --- tests/unit_test/private/fed/server/fed_server_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit_test/private/fed/server/fed_server_test.py b/tests/unit_test/private/fed/server/fed_server_test.py index abf8a21974..cda41d824a 100644 --- a/tests/unit_test/private/fed/server/fed_server_test.py +++ b/tests/unit_test/private/fed/server/fed_server_test.py @@ -16,6 +16,7 @@ import pytest +from nvflare.apis.shareable import Shareable from nvflare.private.defs import CellMessageHeaderKeys, new_cell_message from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.server_state import ColdState, HotState @@ -46,7 +47,8 @@ def test_heart_beat_abort_jobs(self, server_state, expected): CellMessageHeaderKeys.CLIENT_NAME: "client_name", CellMessageHeaderKeys.PROJECT_NAME: "task_name", CellMessageHeaderKeys.JOB_IDS: ["extra_job"], - } + }, + Shareable() ) result = server.client_heartbeat(request) From 95edf8e312bdeed3e0596fbea9fd5909af9c01e2 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 12 Mar 2024 10:22:40 -0400 Subject: [PATCH 38/44] Added CCManager unit tests. --- .../job_schedulers/job_scheduler.py | 3 - .../confidential_computing/cc_manager.py | 4 +- .../confidential_computing/cc_manager_test.py | 133 ++++++++++++++++++ .../private/fed/server/fed_server_test.py | 2 +- 4 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 tests/unit_test/app_opt/confidential_computing/cc_manager_test.py diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index adb40d3558..c8a3d55a2f 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -164,9 +164,6 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp self.log_info(fl_ctx, f"Job {job.job_id} can't be scheduled: {block_reason}") return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason - PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token" - cc_peer_ctx = fl_ctx.get_prop(key=PEER_CTX_CC_TOKEN) - self.logger.info(f"++++++++++ {cc_peer_ctx}") resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx) fl_ctx.set_prop(FLContextKey.CLIENT_RESOURCE_RESULT, resource_check_results, private=True, sticky=False) self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx) diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index a2b6642722..07bfefa659 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -39,6 +39,8 @@ SHUTDOWN_SYSTEM = 1 SHUTDOWN_JOB = 2 +CC_VERIFICATION_FAILED = "not meeting CC requirements" + class CCManager(FLComponent): def __init__( @@ -327,7 +329,7 @@ def _validate_participants_tokens(self, participants) -> str: if invalid_participant_list: invalid_participant_string = ",".join(invalid_participant_list) self.logger.debug(f"{invalid_participant_list=}") - return f"Participant {invalid_participant_string} not meeting CC requirements" + return f"Participant {invalid_participant_string}" + CC_VERIFICATION_FAILED else: return "" diff --git a/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py new file mode 100644 index 0000000000..3db086a277 --- /dev/null +++ b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import Mock + +from nvflare.apis.fl_constant import ReservedKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.server_engine_spec import ServerEngineSpec +from nvflare.app_opt.confidential_computing.cc_manager import ( + CC_INFO, + CC_NAMESPACE, + CC_TOKEN, + CC_TOKEN_VALIDATED, + CC_VERIFICATION_FAILED, + CCManager, +) +from nvflare.app_opt.confidential_computing.tdx_authorizer import TDX_NAMESPACE, TDXAuthorizer + +VALID_TOKEN = "valid_token" +INVALID_TOKEN = "invalid_token" + + +class TestCCManager: + def setup_method(self, method): + issues_conf = [{"issuer_id": "tdx_authorizer", "token_expiration": 250}] + + verify_ids = (["tdx_authorizer"],) + self.cc_manager = CCManager(issues_conf, verify_ids) + + def test_authorizer_setup(self): + + fl_ctx, tdx_authorizer = self._setup_authorizers() + + assert self.cc_manager.cc_issuers == {tdx_authorizer: 250} + assert self.cc_manager.cc_verifiers == {TDX_NAMESPACE: tdx_authorizer} + + def _setup_authorizers(self): + fl_ctx = Mock(spec=FLContext) + fl_ctx.get_identity_name.return_value = "server" + engine = Mock(spec=ServerEngineSpec) + fl_ctx.get_engine.return_value = engine + + tdx_authorizer = Mock(spec=TDXAuthorizer) + tdx_authorizer.get_namespace.return_value = TDX_NAMESPACE + tdx_authorizer.verify = self._verify_token + engine.get_component.return_value = tdx_authorizer + self.cc_manager._setup_cc_authorizers(fl_ctx) + + tdx_authorizer.generate.return_value = VALID_TOKEN + self.cc_manager._generate_tokens(fl_ctx) + + return fl_ctx, tdx_authorizer + + def _verify_token(self, token): + if token == VALID_TOKEN: + return True + else: + return False + + def test_add_client_token(self): + + cc_info1, cc_info2 = self._add_failed_tokens() + + assert self.cc_manager.participant_cc_info["client1"] == cc_info1 + assert self.cc_manager.participant_cc_info["client2"] == cc_info2 + + def _add_failed_tokens(self): + self.cc_manager._verify_running_jobs = Mock() + client_name = "client1" + valid_token = VALID_TOKEN + cc_info1, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + client_name = "client2" + valid_token = INVALID_TOKEN + cc_info2, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + return cc_info1, cc_info2 + + def test_verification_success(self): + + self._setup_authorizers() + + self.cc_manager._verify_running_jobs = Mock() + + self.cc_manager._verify_running_jobs = Mock() + client_name = "client1" + valid_token = VALID_TOKEN + cc_info1, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + client_name = "client2" + valid_token = VALID_TOKEN + cc_info2, fl_ctx = self._add_client_token(client_name, valid_token) + self.cc_manager._add_client_token(fl_ctx) + + self.cc_manager._handle_expired_tokens = Mock() + + err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"]) + + assert not err + + def test_verification_failed(self): + + self._setup_authorizers() + + self.cc_manager._verify_running_jobs = Mock() + self._add_failed_tokens() + self.cc_manager._handle_expired_tokens = Mock() + + err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"]) + + assert "client2" in err + assert CC_VERIFICATION_FAILED in err + + def _add_client_token(self, client_name, valid_token): + peer_ctx = FLContext() + cc_info = [{CC_TOKEN: valid_token, CC_NAMESPACE: TDX_NAMESPACE, CC_TOKEN_VALIDATED: False}] + peer_ctx.set_prop(CC_INFO, cc_info) + peer_ctx.set_prop(ReservedKey.IDENTITY_NAME, client_name) + fl_ctx = Mock(spec=FLContext) + fl_ctx.get_peer_context.return_value = peer_ctx + return cc_info, fl_ctx diff --git a/tests/unit_test/private/fed/server/fed_server_test.py b/tests/unit_test/private/fed/server/fed_server_test.py index cda41d824a..235cfac0c9 100644 --- a/tests/unit_test/private/fed/server/fed_server_test.py +++ b/tests/unit_test/private/fed/server/fed_server_test.py @@ -48,7 +48,7 @@ def test_heart_beat_abort_jobs(self, server_state, expected): CellMessageHeaderKeys.PROJECT_NAME: "task_name", CellMessageHeaderKeys.JOB_IDS: ["extra_job"], }, - Shareable() + Shareable(), ) result = server.client_heartbeat(request) From ecaa7c654b96a0a70d3bec1a81e94503ea90d868 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 12 Mar 2024 11:33:48 -0400 Subject: [PATCH 39/44] Added CCTokenGenerateError and CCTokenVerifyError. Updated CCAuthorizer interface. --- .../confidential_computing/cc_authorizer.py | 28 ++++------ .../confidential_computing/cc_manager.py | 56 ++++++++++--------- .../confidential_computing/gpu_authorizer.py | 11 ++-- .../confidential_computing/tdx_authorizer.py | 6 -- .../private/fed/app/client/client_train.py | 1 + .../fed/app/deployer/server_deployer.py | 1 + 6 files changed, 50 insertions(+), 53 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py index c09a9513a0..1aba63284d 100644 --- a/nvflare/app_opt/confidential_computing/cc_authorizer.py +++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py @@ -15,22 +15,6 @@ class CCAuthorizer: - def can_generate(self) -> bool: - """This indicates if the authorizer can generate a CC token or not. - - Returns: bool - - """ - pass - - def can_verify(self) -> bool: - """This indicates if the authorizer can verify a CC token or not. - - Returns: bool - - """ - pass - def get_namespace(self) -> str: """This returns the namespace of the CCAuthorizer. @@ -57,3 +41,15 @@ def verify(self, token: str) -> bool: """ pass + + +class CCTokenGenerateError(Exception): + """Raised when a CC token generation failed""" + + pass + + +class CCTokenVerifyError(Exception): + """Raised when a CC token verification failed""" + + pass diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index 07bfefa659..f7a72312ab 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -20,7 +20,7 @@ from nvflare.apis.fl_constant import FLContextKey, RunProcessKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError -from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer, CCTokenGenerateError, CCTokenVerifyError from nvflare.fuel.hci.conn import Connection from nvflare.private.fed.server.training_cmds import TrainingCommandModule @@ -148,13 +148,13 @@ def _setup_cc_authorizers(self, fl_ctx): issuer_id = conf.get(CC_ISSUER_ID) expiration = conf.get(TOKEN_EXPIRATION) issuer = engine.get_component(issuer_id) - if not (isinstance(issuer, CCAuthorizer) and issuer.can_generate()): + if not isinstance(issuer, CCAuthorizer): raise RuntimeError(f"cc_issuer_id {issuer_id} must be a CCAuthorizer, but got {issuer.__class__}") self.cc_issuers[issuer] = expiration for v_id in self.cc_verifier_ids: verifier = engine.get_component(v_id) - if not (isinstance(verifier, CCAuthorizer) and verifier.can_verify()): + if not isinstance(verifier, CCAuthorizer): raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {verifier.__class__}") namespace = verifier.get_namespace() if namespace in self.cc_verifiers.keys(): @@ -222,25 +222,28 @@ def _generate_tokens(self, fl_ctx: FLContext) -> str: self.participant_cc_info[self.site_name] = [] for issuer, expiration in self.cc_issuers.items(): - my_token = issuer.generate() - namespace = issuer.get_namespace() - - if not isinstance(expiration, int): - raise ValueError(f"token_expiration value must be int, but got {expiration.__class__}") - if not my_token: - return "failed to get CC token" - - self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") - cc_info = { - CC_TOKEN: my_token, - CC_ISSUER: issuer, - CC_NAMESPACE: namespace, - TOKEN_GENERATION_TIME: time.time(), - TOKEN_EXPIRATION: int(expiration), - CC_TOKEN_VALIDATED: True, - } - self.participant_cc_info[self.site_name].append(cc_info) - self.token_submitted = False + try: + my_token = issuer.generate() + namespace = issuer.get_namespace() + + if not isinstance(expiration, int): + raise ValueError(f"token_expiration value must be int, but got {expiration.__class__}") + if not my_token: + return f"{issuer} failed to get CC token" + + self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}") + cc_info = { + CC_TOKEN: my_token, + CC_ISSUER: issuer, + CC_NAMESPACE: namespace, + TOKEN_GENERATION_TIME: time.time(), + TOKEN_EXPIRATION: int(expiration), + CC_TOKEN_VALIDATED: True, + } + self.participant_cc_info[self.site_name].append(cc_info) + self.token_submitted = False + except CCTokenGenerateError: + raise RuntimeError(f"{issuer} failed to generate CC token.") return "" @@ -343,9 +346,12 @@ def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) token = v.get(CC_TOKEN, "") namespace = v.get(CC_NAMESPACE, "") verifier = self.cc_verifiers.get(namespace, None) - if verifier and verifier.verify(token): - result[k + "." + namespace] = True - else: + try: + if verifier and verifier.verify(token): + result[k + "." + namespace] = True + else: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") + except CCTokenVerifyError: invalid_participant_list.append(k + " namespace: {" + namespace + "}") self.logger.info(f"CC - results from validating participants' tokens: {result}") return result, invalid_participant_list diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py index 73cf8c7682..e11d1ddba0 100644 --- a/nvflare/app_opt/confidential_computing/gpu_authorizer.py +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -19,6 +19,11 @@ class GPUAuthorizer(CCAuthorizer): + """Note: This is just a fake implementation for GPU authorizer. It will be replaced later + with the real implementation. + + """ + def __init__(self, verifiers: list) -> None: """ @@ -78,12 +83,6 @@ def __init__(self, verifiers: list) -> None: super().__init__() self.verifiers = verifiers - def can_generate(self) -> bool: - return True - - def can_verify(self) -> bool: - return True - def get_namespace(self) -> str: return GPU_NAMESPACE diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py index 3825588dd8..87bf31c991 100644 --- a/nvflare/app_opt/confidential_computing/tdx_authorizer.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -63,11 +63,5 @@ def verify(self, token: str) -> bool: return True - def can_generate(self) -> bool: - return True - - def can_verify(self) -> bool: - return True - def get_namespace(self) -> str: return TDX_NAMESPACE diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index adb5a4fac5..6b42767ecc 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -116,6 +116,7 @@ def main(args): if exceptions: for _, exception in exceptions.items(): if isinstance(exception, UnsafeComponentError): + print("Unsafe component configured, could not start the client!!") raise RuntimeError(exception) client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index 40924392b8..6c3dd4acae 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -127,6 +127,7 @@ def deploy(self, args): if exceptions: for _, exception in exceptions.items(): if isinstance(exception, UnsafeComponentError): + print("Unsafe component configured, could not start the server!!") raise RuntimeError(exception) threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() From 60a5c9d5a1bd0fb42d8e3a29c151f811b021cce4 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Thu, 14 Mar 2024 13:45:07 -0400 Subject: [PATCH 40/44] Addressed some PR reviews. --- .../confidential_computing/cc_authorizer.py | 6 +++++- .../app_opt/confidential_computing/cc_manager.py | 5 +++++ .../confidential_computing/gpu_authorizer.py | 4 ++-- .../confidential_computing/tdx_authorizer.py | 14 ++++++++++---- nvflare/private/fed/app/client/client_train.py | 10 ++-------- .../private/fed/app/deployer/server_deployer.py | 9 ++------- nvflare/private/fed/app/utils.py | 12 ++++++++++++ 7 files changed, 38 insertions(+), 22 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py index 1aba63284d..aaf610d88c 100644 --- a/nvflare/app_opt/confidential_computing/cc_authorizer.py +++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # import os.path +from abc import ABC, abstractmethod -class CCAuthorizer: +class CCAuthorizer(ABC): + @abstractmethod def get_namespace(self) -> str: """This returns the namespace of the CCAuthorizer. @@ -23,6 +25,7 @@ def get_namespace(self) -> str: """ pass + @abstractmethod def generate(self) -> str: """To generate and return the active CCAuthorizer token. @@ -31,6 +34,7 @@ def generate(self) -> str: """ pass + @abstractmethod def verify(self, token: str) -> bool: """To return the token verification result. diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index f7a72312ab..fe8e1442a2 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -61,6 +61,11 @@ def __init__( shutdown the running jobs if CC tokens expired Args: + cc_issuers_conf: configuration of the CC token issuers. each contains the CC token issuer component ID, + and the token expiration time + cc_verifier_ids: CC token verifiers component IDs + verify_frequency: CC tokens verification frequency + critical_level: critical_level """ FLComponent.__init__(self) diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py index e11d1ddba0..bd55e2a463 100644 --- a/nvflare/app_opt/confidential_computing/gpu_authorizer.py +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -87,7 +87,7 @@ def get_namespace(self) -> str: return GPU_NAMESPACE def generate(self) -> str: - return super().generate() + raise NotImplementedError def verify(self, token: str) -> bool: - return super().verify(token) + raise NotImplementedError diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py index 87bf31c991..b3f74e386c 100644 --- a/nvflare/app_opt/confidential_computing/tdx_authorizer.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -41,12 +41,15 @@ def generate(self) -> str: command = ["sudo", self.tdx_cli_command, "-c", self.config_file, "token", "--no-eventlog"] subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) - with open(error_file, "r") as f: - if "Error:" in f.read(): + if not os.path.exists(error_file) or not os.path.exists(token_file): + return "" + + with open(error_file, "r") as e_f: + if "Error:" in e_f.read(): return "" else: - with open(token_file, "r") as f: - token = f.readline() + with open(token_file, "r") as t_f: + token = t_f.readline() return token def verify(self, token: str) -> bool: @@ -57,6 +60,9 @@ def verify(self, token: str) -> bool: command = [self.tdx_cli_command, "verify", "--config", self.config_file, "--token", token] subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out) + if not os.path.exists(error_file): + return False + with open(error_file, "r") as f: if "Error:" in f.read(): return False diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py index 6b42767ecc..2618974a81 100644 --- a/nvflare/private/fed/app/client/client_train.py +++ b/nvflare/private/fed/app/client/client_train.py @@ -21,14 +21,13 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, JobConstants, SiteType, WorkspaceConstants -from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.workspace import Workspace from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager -from nvflare.private.fed.app.utils import version_check +from nvflare.private.fed.app.utils import component_security_check, version_check from nvflare.private.fed.client.admin import FedAdminAgent from nvflare.private.fed.client.client_engine import ClientEngine from nvflare.private.fed.client.client_status import ClientStatus @@ -112,12 +111,7 @@ def main(args): fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) - exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) - if exceptions: - for _, exception in exceptions.items(): - if isinstance(exception, UnsafeComponentError): - print("Unsafe component configured, could not start the client!!") - raise RuntimeError(exception) + component_security_check(fl_ctx) client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx) federated_client.register(fl_ctx) diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py index 6c3dd4acae..ecb2a1b04f 100644 --- a/nvflare/private/fed/app/deployer/server_deployer.py +++ b/nvflare/private/fed/app/deployer/server_deployer.py @@ -17,8 +17,8 @@ from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import FLContextKey, SystemComponents -from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.apis.workspace import Workspace +from nvflare.private.fed.app.utils import component_security_check from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.job_runner import JobRunner from nvflare.private.fed.server.run_manager import RunManager @@ -123,12 +123,7 @@ def deploy(self, args): fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) - exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) - if exceptions: - for _, exception in exceptions.items(): - if isinstance(exception, UnsafeComponentError): - print("Unsafe component configured, could not start the server!!") - raise RuntimeError(exception) + component_security_check(fl_ctx) threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() services.status = ServerStatus.STARTED diff --git a/nvflare/private/fed/app/utils.py b/nvflare/private/fed/app/utils.py index 8552ec4945..94712d4b21 100644 --- a/nvflare/private/fed/app/utils.py +++ b/nvflare/private/fed/app/utils.py @@ -20,6 +20,9 @@ import psutil +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.fuel.hci.security import hash_password from nvflare.private.defs import SSLConstants from nvflare.private.fed.runner import Runner @@ -98,3 +101,12 @@ def version_check(): raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.") if sys.version_info < (3, 8): raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10") + + +def component_security_check(fl_ctx: FLContext): + exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + if exceptions: + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + print(f"Unsafe component configured, could not start {fl_ctx.get_identity_name()}!!") + raise RuntimeError(exception) From 9244fbf3cfef9f1e04a40142928c5bdffa9dd77f Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 15 Mar 2024 11:47:28 -0400 Subject: [PATCH 41/44] Added exception catch for TDXAuthorizer. --- .../confidential_computing/tdx_authorizer.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py index b3f74e386c..21bff9035e 100644 --- a/nvflare/app_opt/confidential_computing/tdx_authorizer.py +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -44,13 +44,16 @@ def generate(self) -> str: if not os.path.exists(error_file) or not os.path.exists(token_file): return "" - with open(error_file, "r") as e_f: - if "Error:" in e_f.read(): - return "" - else: - with open(token_file, "r") as t_f: - token = t_f.readline() - return token + try: + with open(error_file, "r") as e_f: + if "Error:" in e_f.read(): + return "" + else: + with open(token_file, "r") as t_f: + token = t_f.readline() + return token + except: + return "" def verify(self, token: str) -> bool: out = open(os.path.join(self.config_dir, VERIFY_FILE), "w") @@ -63,9 +66,12 @@ def verify(self, token: str) -> bool: if not os.path.exists(error_file): return False - with open(error_file, "r") as f: - if "Error:" in f.read(): - return False + try: + with open(error_file, "r") as f: + if "Error:" in f.read(): + return False + except: + return False return True From dbc87d147147f1bb1cddb80739342aef645e6356 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Mon, 18 Mar 2024 13:20:38 -0400 Subject: [PATCH 42/44] renamed some events. --- .../security/server/custom/security_handler.py | 2 +- nvflare/apis/event_type.py | 5 ++++- nvflare/app_opt/confidential_computing/cc_manager.py | 2 +- nvflare/private/fed/client/communicator.py | 1 + nvflare/private/fed/server/fed_server.py | 6 ++++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/examples/advanced/custom_authentication/security/server/custom/security_handler.py b/examples/advanced/custom_authentication/security/server/custom/security_handler.py index 4b5956ff67..bd3c1b4165 100644 --- a/examples/advanced/custom_authentication/security/server/custom/security_handler.py +++ b/examples/advanced/custom_authentication/security/server/custom/security_handler.py @@ -20,7 +20,7 @@ class ServerCustomSecurityHandler(FLComponent): def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.CLIENT_REGISTERED: + if event_type == EventType.RECEIVED_CLIENT_REGISTER: self.authenticate(fl_ctx=fl_ctx) def authenticate(self, fl_ctx: FLContext): diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 6e473aff4d..79b01a60e6 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -73,11 +73,14 @@ class EventType(object): BEFORE_CLIENT_REGISTER = "_before_client_register" AFTER_CLIENT_REGISTER = "_after_client_register" - CLIENT_REGISTERED = "_client_registered" + RECEIVED_CLIENT_REGISTER = "_received_client_register" + PROCEEDED_CLIENT_REGISTER = "_proceeded_client_register" CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat" + RECEIVED_CLIENT_HEARTBEAT = "_received_client_heartbeat" + PROCEEDED_CLIENT_HEARTBEAT = "_proceeded_client_heartbeat" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index fe8e1442a2..da566382ed 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -105,7 +105,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side self._prepare_cc_info(fl_ctx) - elif event_type == EventType.CLIENT_REGISTERED or event_type == EventType.AFTER_CLIENT_HEARTBEAT: + elif event_type == EventType.RECEIVED_CLIENT_REGISTER or event_type == EventType.RECEIVED_CLIENT_HEARTBEAT: # Server side self._add_client_token(fl_ctx) elif event_type == EventType.CLIENT_QUIT: diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 1a94384d5d..e95e2f5269 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -378,6 +378,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C except Exception as ex: raise FLCommunicationError("error:client_quit", ex) + engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx) for i in range(wait_times): time.sleep(2) if self.heartbeat_done: diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index e87c04e6a4..f8225d62c4 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -493,7 +493,7 @@ def register_client(self, request: Message) -> Message: shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.CLIENT_REGISTERED, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.RECEIVED_CLIENT_REGISTER, fl_ctx=fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) if exceptions: @@ -513,6 +513,7 @@ def register_client(self, request: Message) -> Message: } else: headers = {} + self.engine.fire_event(EventType.PROCEEDED_CLIENT_REGISTER, fl_ctx=fl_ctx) return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}") @@ -580,7 +581,7 @@ def client_heartbeat(self, request: Message) -> Message: data = request.payload shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.RECEIVED_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) token = request.get_header(CellMessageHeaderKeys.TOKEN) client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME) @@ -603,6 +604,7 @@ def client_heartbeat(self, request: Message) -> Message: f"These jobs: {display_runs} are not running on the server. " f"Ask client: {client_name} to abort these runs." ) + self.engine.fire_event(EventType.PROCEEDED_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) return reply def _sync_client_jobs(self, request, client_token): From 78b52e6af544df19e7e9da1ddac2a66312ec1c1c Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 19 Mar 2024 09:28:56 -0400 Subject: [PATCH 43/44] renamed event names. --- nvflare/apis/event_type.py | 4 ++-- nvflare/apis/fl_constant.py | 2 +- nvflare/app_common/job_schedulers/job_scheduler.py | 2 +- nvflare/app_opt/confidential_computing/cc_manager.py | 4 ++-- nvflare/private/fed/server/fed_server.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 79b01a60e6..8f476f2a1a 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -79,8 +79,8 @@ class EventType(object): SYSTEM_BOOTSTRAP = "_system_bootstrap" BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat" - RECEIVED_CLIENT_HEARTBEAT = "_received_client_heartbeat" - PROCEEDED_CLIENT_HEARTBEAT = "_proceeded_client_heartbeat" + CLIENT_HEARTBEAT_RECEIVED = "_client_heartbeat_received" + CLIENT_HEARTBEAT_PROCESSED = "_client_heartbeat_processed" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 1daafa4619..325050aac3 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -150,7 +150,7 @@ class FLContextKey(object): COMMUNICATION_ERROR = "Flare_communication_error__" UNAUTHENTICATED = "Flare_unauthenticated__" CLIENT_RESOURCE_SPECS = "__client_resource_specs" - CLIENT_RESOURCE_RESULT = "__client_resource_result" + RESOURCE_CHECK_RESULT = "__resource_check_result" JOB_PARTICIPANTS = "__job_participants" JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling SSID = "__ssid__" diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py index c8a3d55a2f..c7e03d394f 100644 --- a/nvflare/app_common/job_schedulers/job_scheduler.py +++ b/nvflare/app_common/job_schedulers/job_scheduler.py @@ -165,7 +165,7 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx) - fl_ctx.set_prop(FLContextKey.CLIENT_RESOURCE_RESULT, resource_check_results, private=True, sticky=False) + fl_ctx.set_prop(FLContextKey.RESOURCE_CHECK_RESULT, resource_check_results, private=True, sticky=False) self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx) if not resource_check_results: diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index da566382ed..e1fe85a0f1 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -105,7 +105,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side self._prepare_cc_info(fl_ctx) - elif event_type == EventType.RECEIVED_CLIENT_REGISTER or event_type == EventType.RECEIVED_CLIENT_HEARTBEAT: + elif event_type == EventType.RECEIVED_CLIENT_REGISTER or event_type == EventType.CLIENT_HEARTBEAT_RECEIVED: # Server side self._add_client_token(fl_ctx) elif event_type == EventType.CLIENT_QUIT: @@ -135,7 +135,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): else: threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() elif event_type == EventType.AFTER_CHECK_CLIENT_RESOURCES: - client_resource_result = fl_ctx.get_prop(FLContextKey.CLIENT_RESOURCE_RESULT) + client_resource_result = fl_ctx.get_prop(FLContextKey.RESOURCE_CHECK_RESULT) if client_resource_result: for site_name, check_result in client_resource_result.items(): is_resource_enough, reason = check_result diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index f8225d62c4..7f10f2cc80 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -581,7 +581,7 @@ def client_heartbeat(self, request: Message) -> Message: data = request.payload shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.RECEIVED_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_RECEIVED, fl_ctx=fl_ctx) token = request.get_header(CellMessageHeaderKeys.TOKEN) client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME) @@ -604,7 +604,7 @@ def client_heartbeat(self, request: Message) -> Message: f"These jobs: {display_runs} are not running on the server. " f"Ask client: {client_name} to abort these runs." ) - self.engine.fire_event(EventType.PROCEEDED_CLIENT_HEARTBEAT, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_PROCESSED, fl_ctx=fl_ctx) return reply def _sync_client_jobs(self, request, client_token): From 108a1ec19860955d66b54321de1195fe1889b859 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 19 Mar 2024 12:19:05 -0400 Subject: [PATCH 44/44] renamed event names. --- .../security/server/custom/security_handler.py | 2 +- nvflare/apis/event_type.py | 4 ++-- nvflare/app_opt/confidential_computing/cc_manager.py | 2 +- nvflare/private/fed/server/fed_server.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/advanced/custom_authentication/security/server/custom/security_handler.py b/examples/advanced/custom_authentication/security/server/custom/security_handler.py index bd3c1b4165..c10ecb07ec 100644 --- a/examples/advanced/custom_authentication/security/server/custom/security_handler.py +++ b/examples/advanced/custom_authentication/security/server/custom/security_handler.py @@ -20,7 +20,7 @@ class ServerCustomSecurityHandler(FLComponent): def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.RECEIVED_CLIENT_REGISTER: + if event_type == EventType.CLIENT_REGISTER_RECEIVED: self.authenticate(fl_ctx=fl_ctx) def authenticate(self, fl_ctx: FLContext): diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 8f476f2a1a..1090961ada 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -73,8 +73,8 @@ class EventType(object): BEFORE_CLIENT_REGISTER = "_before_client_register" AFTER_CLIENT_REGISTER = "_after_client_register" - RECEIVED_CLIENT_REGISTER = "_received_client_register" - PROCEEDED_CLIENT_REGISTER = "_proceeded_client_register" + CLIENT_REGISTER_RECEIVED = "_client_register_received" + CLIENT_REGISTER_PROCESSED = "_client_register_processed" CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py index e1fe85a0f1..052a61f5b4 100644 --- a/nvflare/app_opt/confidential_computing/cc_manager.py +++ b/nvflare/app_opt/confidential_computing/cc_manager.py @@ -105,7 +105,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side self._prepare_cc_info(fl_ctx) - elif event_type == EventType.RECEIVED_CLIENT_REGISTER or event_type == EventType.CLIENT_HEARTBEAT_RECEIVED: + elif event_type == EventType.CLIENT_REGISTER_RECEIVED or event_type == EventType.CLIENT_HEARTBEAT_RECEIVED: # Server side self._add_client_token(fl_ctx) elif event_type == EventType.CLIENT_QUIT: diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 7f10f2cc80..ae50e4c3a6 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -493,7 +493,7 @@ def register_client(self, request: Message) -> Message: shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.RECEIVED_CLIENT_REGISTER, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_REGISTER_RECEIVED, fl_ctx=fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) if exceptions: @@ -513,7 +513,7 @@ def register_client(self, request: Message) -> Message: } else: headers = {} - self.engine.fire_event(EventType.PROCEEDED_CLIENT_REGISTER, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx) return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}")