diff --git a/examples/advanced/custom_authentication/security/server/custom/security_handler.py b/examples/advanced/custom_authentication/security/server/custom/security_handler.py index 4b5956ff67..c10ecb07ec 100644 --- a/examples/advanced/custom_authentication/security/server/custom/security_handler.py +++ b/examples/advanced/custom_authentication/security/server/custom/security_handler.py @@ -20,7 +20,7 @@ class ServerCustomSecurityHandler(FLComponent): def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.CLIENT_REGISTERED: + if event_type == EventType.CLIENT_REGISTER_RECEIVED: self.authenticate(fl_ctx=fl_ctx) def authenticate(self, fl_ctx: FLContext): diff --git a/nvflare/apis/event_type.py b/nvflare/apis/event_type.py index 4c65843e70..1090961ada 100644 --- a/nvflare/apis/event_type.py +++ b/nvflare/apis/event_type.py @@ -73,8 +73,14 @@ class EventType(object): BEFORE_CLIENT_REGISTER = "_before_client_register" AFTER_CLIENT_REGISTER = "_after_client_register" - CLIENT_REGISTERED = "_client_registered" + CLIENT_REGISTER_RECEIVED = "_client_register_received" + CLIENT_REGISTER_PROCESSED = "_client_register_processed" + CLIENT_QUIT = "_client_quit" SYSTEM_BOOTSTRAP = "_system_bootstrap" + BEFORE_CLIENT_HEARTBEAT = "_before_client_heartbeat" + AFTER_CLIENT_HEARTBEAT = "_after_client_heartbeat" + CLIENT_HEARTBEAT_RECEIVED = "_client_heartbeat_received" + CLIENT_HEARTBEAT_PROCESSED = "_client_heartbeat_processed" AUTHORIZE_COMMAND_CHECK = "_authorize_command_check" BEFORE_BUILD_COMPONENT = "_before_build_component" diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index af15ac6504..325050aac3 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -150,6 +150,7 @@ class FLContextKey(object): COMMUNICATION_ERROR = "Flare_communication_error__" UNAUTHENTICATED = "Flare_unauthenticated__" CLIENT_RESOURCE_SPECS = "__client_resource_specs" + RESOURCE_CHECK_RESULT = "__resource_check_result" JOB_PARTICIPANTS = "__job_participants" JOB_BLOCK_REASON = "__job_block_reason" # why the job should be blocked from scheduling SSID = "__ssid__" diff --git a/nvflare/apis/server_engine_spec.py b/nvflare/apis/server_engine_spec.py index 1b1e8c14c4..004e623056 100644 --- a/nvflare/apis/server_engine_spec.py +++ b/nvflare/apis/server_engine_spec.py @@ -154,12 +154,13 @@ def restore_components(self, snapshot: RunSnapshot, fl_ctx: FLContext): pass @abstractmethod - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): """To send the start client run commands to the clients Args: client_sites: client sites job_id: job_id + fl_ctx: FLContext Returns: @@ -187,7 +188,7 @@ def check_client_resources( @abstractmethod def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): """Cancels the request resources for the job. 
@@ -195,6 +196,7 @@ def cancel_client_resources(
             resource_check_results: A dict of {client_name: client_check_result}
               where client_check_result is a tuple of (is_resource_enough, resource reserve token if any)
             resource_reqs: A dict of {client_name: resource requirements dict}
+            fl_ctx: FLContext
 
         """
         pass
diff --git a/nvflare/app_common/job_schedulers/job_scheduler.py b/nvflare/app_common/job_schedulers/job_scheduler.py
index 2619b98d68..c7e03d394f 100644
--- a/nvflare/app_common/job_schedulers/job_scheduler.py
+++ b/nvflare/app_common/job_schedulers/job_scheduler.py
@@ -93,7 +93,7 @@ def _cancel_resources(
         if not isinstance(engine, ServerEngineSpec):
             raise RuntimeError(f"engine inside fl_ctx should be of type ServerEngineSpec, but got {type(engine)}.")
 
-        engine.cancel_client_resources(resource_check_results, resource_reqs)
+        engine.cancel_client_resources(resource_check_results, resource_reqs, fl_ctx)
         self.log_debug(fl_ctx, f"cancel client resources using check results: {resource_check_results}")
         return False, None
 
@@ -165,6 +165,7 @@ def _try_job(self, job: Job, fl_ctx: FLContext) -> (int, Optional[Dict[str, Disp
             return SCHEDULE_RESULT_NO_RESOURCE, None, block_reason
 
         resource_check_results = self._check_client_resources(job=job, resource_reqs=resource_reqs, fl_ctx=fl_ctx)
+        fl_ctx.set_prop(FLContextKey.RESOURCE_CHECK_RESULT, resource_check_results, private=True, sticky=False)
         self.fire_event(EventType.AFTER_CHECK_CLIENT_RESOURCES, fl_ctx)
 
         if not resource_check_results:
diff --git a/nvflare/app_opt/confidential_computing/cc_authorizer.py b/nvflare/app_opt/confidential_computing/cc_authorizer.py
new file mode 100644
index 0000000000..aaf610d88c
--- /dev/null
+++ b/nvflare/app_opt/confidential_computing/cc_authorizer.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+
+
+class CCAuthorizer(ABC):
+    @abstractmethod
+    def get_namespace(self) -> str:
+        """Returns the namespace of the CCAuthorizer.
+
+        Returns: namespace string
+
+        """
+        pass
+
+    @abstractmethod
+    def generate(self) -> str:
+        """Generates and returns the active CCAuthorizer token.
+
+        Returns: token string
+
+        """
+        pass
+
+    @abstractmethod
+    def verify(self, token: str) -> bool:
+        """Returns the token verification result.
+
+        Args:
+            token: the token string to verify
+
+        Returns:
+            True if the token is valid, False otherwise
+
+        """
+        pass
+
+
+class CCTokenGenerateError(Exception):
+    """Raised when CC token generation fails"""
+
+    pass
+
+
+class CCTokenVerifyError(Exception):
+    """Raised when CC token verification fails"""
+
+    pass
diff --git a/nvflare/app_opt/confidential_computing/cc_manager.py b/nvflare/app_opt/confidential_computing/cc_manager.py
index 82fa7dba8a..052a61f5b4 100644
--- a/nvflare/app_opt/confidential_computing/cc_manager.py
+++ b/nvflare/app_opt/confidential_computing/cc_manager.py
@@ -11,95 +11,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import threading
+import time
+from typing import Dict, List
 
 from nvflare.apis.event_type import EventType
 from nvflare.apis.fl_component import FLComponent
-from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey
+from nvflare.apis.fl_constant import FLContextKey, RunProcessKey
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.fl_exception import UnsafeComponentError
-
-from .cc_helper import CCHelper
+from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer, CCTokenGenerateError, CCTokenVerifyError
+from nvflare.fuel.hci.conn import Connection
+from nvflare.private.fed.server.training_cmds import TrainingCommandModule
 
 PEER_CTX_CC_TOKEN = "_peer_ctx_cc_token"
 CC_TOKEN = "_cc_token"
+CC_ISSUER = "_cc_issuer"
+CC_NAMESPACE = "_cc_namespace"
 CC_INFO = "_cc_info"
 CC_TOKEN_VALIDATED = "_cc_token_validated"
+CC_VERIFY_ERROR = "_cc_verify_error."
+
+CC_ISSUER_ID = "issuer_id"
+TOKEN_GENERATION_TIME = "token_generation_time"
+TOKEN_EXPIRATION = "token_expiration"
+
+SHUTDOWN_SYSTEM = 1
+SHUTDOWN_JOB = 2
+
+CC_VERIFICATION_FAILED = "not meeting CC requirements"
 
 
 class CCManager(FLComponent):
-    def __init__(self, verifiers: list):
+    def __init__(
+        self,
+        cc_issuers_conf: List[Dict[str, str]],
+        cc_verifier_ids: List[str],
+        verify_frequency=600,
+        critical_level=SHUTDOWN_JOB,
+    ):
         """Manage all confidential computing related tasks.
 
         This manager does the following tasks:
-            obtaining its own GPU CC token
+            obtaining its own CC token
             preparing the token to the server
             keeping clients' tokens in server
             validating all tokens in the entire NVFlare system
+            not allowing the system to start if it fails to get a CC token
+            shutting down running jobs if CC tokens expire
 
         Args:
-            verifiers (list):
-                each element in this list is a dictionary and the keys of dictionary are
-                "devices", "env", "url", "appraisal_policy_file" and "result_policy_file."
-
-                the values of devices are "gpu" and "cpu"
-                the values of env are "local" and "test"
-                currently, valid combination is gpu + local
-
-                url must be an empty string
-                appraisal_policy_file must point to an existing file
-                currently supports an empty file only
-
-                result_policy_file must point to an existing file
-                currently supports the following content only
-
-                .. code-block:: json
-
-                    {
-                        "version":"1.0",
-                        "authorization-rules":{
-                            "x-nv-gpu-available":true,
-                            "x-nv-gpu-attestation-report-available":true,
-                            "x-nv-gpu-info-fetched":true,
-                            "x-nv-gpu-arch-check":true,
-                            "x-nv-gpu-root-cert-available":true,
-                            "x-nv-gpu-cert-chain-verified":true,
-                            "x-nv-gpu-ocsp-cert-chain-verified":true,
-                            "x-nv-gpu-ocsp-signature-verified":true,
-                            "x-nv-gpu-cert-ocsp-nonce-match":true,
-                            "x-nv-gpu-cert-check-complete":true,
-                            "x-nv-gpu-measurement-available":true,
-                            "x-nv-gpu-attestation-report-parsed":true,
-                            "x-nv-gpu-nonce-match":true,
-                            "x-nv-gpu-attestation-report-driver-version-match":true,
-                            "x-nv-gpu-attestation-report-vbios-version-match":true,
-                            "x-nv-gpu-attestation-report-verified":true,
-                            "x-nv-gpu-driver-rim-schema-fetched":true,
-                            "x-nv-gpu-driver-rim-schema-validated":true,
-                            "x-nv-gpu-driver-rim-cert-extracted":true,
-                            "x-nv-gpu-driver-rim-signature-verified":true,
-                            "x-nv-gpu-driver-rim-driver-measurements-available":true,
-                            "x-nv-gpu-driver-vbios-rim-fetched":true,
-                            "x-nv-gpu-vbios-rim-schema-validated":true,
-                            "x-nv-gpu-vbios-rim-cert-extracted":true,
-                            "x-nv-gpu-vbios-rim-signature-verified":true,
-                            "x-nv-gpu-vbios-rim-driver-measurements-available":true,
-                            "x-nv-gpu-vbios-index-conflict":true,
-                            "x-nv-gpu-measurements-match":true
-                        }
-                    }
 
+            cc_issuers_conf: configuration of the CC token issuers; each entry contains the component ID of a
+                CC token issuer and its token expiration time
+            cc_verifier_ids: component IDs of the CC token verifiers
+            verify_frequency: how often (in seconds) to verify the CC tokens
+            critical_level: action to take when CC verification fails: SHUTDOWN_JOB or SHUTDOWN_SYSTEM
         """
         FLComponent.__init__(self)
         self.site_name = None
-        self.helper = None
-        self.verifiers = verifiers
-        self.my_token = None
+        self.cc_issuers_conf = cc_issuers_conf
+        self.cc_verifier_ids = cc_verifier_ids
+
+        if not isinstance(verify_frequency, int):
+            raise ValueError(f"verify_frequency must be int, but got {verify_frequency.__class__}")
+        self.verify_frequency = int(verify_frequency)
+
+        self.critical_level = critical_level
+        if self.critical_level not in [SHUTDOWN_SYSTEM, SHUTDOWN_JOB]:
+            raise ValueError(f"critical_level must be in [{SHUTDOWN_SYSTEM}, {SHUTDOWN_JOB}]. 
But got {critical_level}") + + self.verify_time = None + self.cc_issuers = {} + self.cc_verifiers = {} self.participant_cc_info = {} # used by the Server to keep tokens of all clients + self.token_submitted = False + self.lock = threading.Lock() + def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.SYSTEM_BOOTSTRAP: try: - err = self._prepare_for_attestation(fl_ctx) + self._setup_cc_authorizers(fl_ctx) + + err = self._generate_tokens(fl_ctx) except: self.log_exception(fl_ctx, "exception in attestation preparation") err = "exception in attestation preparation" @@ -107,26 +102,27 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): if err: self.log_critical(fl_ctx, err, fire_event=False) raise UnsafeComponentError(err) - elif event_type == EventType.BEFORE_CLIENT_REGISTER: + elif event_type == EventType.BEFORE_CLIENT_REGISTER or event_type == EventType.BEFORE_CLIENT_HEARTBEAT: # On client side - self._prepare_token_for_login(fl_ctx) - elif event_type == EventType.CLIENT_REGISTERED: + self._prepare_cc_info(fl_ctx) + elif event_type == EventType.CLIENT_REGISTER_RECEIVED or event_type == EventType.CLIENT_HEARTBEAT_RECEIVED: # Server side self._add_client_token(fl_ctx) - elif event_type == EventType.AUTHORIZE_COMMAND_CHECK: - command_to_check = fl_ctx.get_prop(key=FLContextKey.COMMAND_NAME) - self.logger.debug(f"Received {command_to_check=}") - if command_to_check == AdminCommandNames.CHECK_RESOURCES: - try: - err = self._client_to_check_participant_token(fl_ctx) - except: - self.log_exception(fl_ctx, "exception in validating participants") - err = "Participants unable to meet client CC requirements" - finally: - if err: - self._not_authorize_job(err, fl_ctx) - elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: + elif event_type == EventType.CLIENT_QUIT: # Server side + self._remove_client_token(fl_ctx) + elif event_type == EventType.BEFORE_CHECK_RESOURCE_MANAGER: + # Client side: check resources before job scheduled + try: + err = self._client_to_check_participant_token(fl_ctx) + except: + self.log_exception(fl_ctx, "exception in validating participants") + err = "Participants unable to meet client CC requirements" + finally: + if err: + self._block_job(err, fl_ctx) + elif event_type == EventType.BEFORE_CHECK_CLIENT_RESOURCES: + # Server side: job scheduler check client resources try: err = self._server_to_check_client_token(fl_ctx) except: @@ -134,35 +130,126 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): err = "Clients unable to meet server CC requirements" finally: if err: - self._block_job(err, fl_ctx) + if self.critical_level == SHUTDOWN_JOB: + self._block_job(err, fl_ctx) + else: + threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start() elif event_type == EventType.AFTER_CHECK_CLIENT_RESOURCES: - # Server side - fl_ctx.remove_prop(PEER_CTX_CC_TOKEN) + client_resource_result = fl_ctx.get_prop(FLContextKey.RESOURCE_CHECK_RESULT) + if client_resource_result: + for site_name, check_result in client_resource_result.items(): + is_resource_enough, reason = check_result + if ( + not is_resource_enough + and reason.startswith(CC_VERIFY_ERROR) + and self.critical_level == SHUTDOWN_SYSTEM + ): + threading.Thread(target=self._shutdown_system, args=[reason, fl_ctx]).start() + break + + def _setup_cc_authorizers(self, fl_ctx): + engine = fl_ctx.get_engine() + for conf in self.cc_issuers_conf: + issuer_id = conf.get(CC_ISSUER_ID) + expiration = conf.get(TOKEN_EXPIRATION) + issuer = engine.get_component(issuer_id) + 
+            if not isinstance(issuer, CCAuthorizer):
+                raise RuntimeError(f"cc_issuer_id {issuer_id} must be a CCAuthorizer, but got {issuer.__class__}")
+            self.cc_issuers[issuer] = expiration
+
+        for v_id in self.cc_verifier_ids:
+            verifier = engine.get_component(v_id)
+            if not isinstance(verifier, CCAuthorizer):
+                raise RuntimeError(f"cc_authorizer_id {v_id} must be a CCAuthorizer, but got {verifier.__class__}")
+            namespace = verifier.get_namespace()
+            if namespace in self.cc_verifiers.keys():
+                raise RuntimeError(f"Authorizer with namespace: {namespace} already exists.")
+            self.cc_verifiers[namespace] = verifier
 
-    def _prepare_token_for_login(self, fl_ctx: FLContext):
-        # client side
-        if self.my_token is None:
-            self.my_token = self.helper.get_token()
-        cc_info = {CC_TOKEN: self.my_token}
-        fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False)
+    def _prepare_cc_info(self, fl_ctx: FLContext):
+        # client side: if the token has expired, generate a new one
+        self._handle_expired_tokens()
+
+        if not self.token_submitted:
+            site_cc_info = self.participant_cc_info[self.site_name]
+            cc_info = self._get_participant_tokens(site_cc_info)
+            fl_ctx.set_prop(key=CC_INFO, value=cc_info, sticky=False, private=False)
+            self.logger.info("Sent the CC tokens to the server.")
+            self.token_submitted = True
 
     def _add_client_token(self, fl_ctx: FLContext):
         # server side
         peer_ctx = fl_ctx.get_peer_context()
         token_owner = peer_ctx.get_identity_name()
         peer_cc_info = peer_ctx.get_prop(CC_INFO)
-        self.participant_cc_info[token_owner] = peer_cc_info
-        self.participant_cc_info[token_owner][CC_TOKEN_VALIDATED] = False
 
-    def _prepare_for_attestation(self, fl_ctx: FLContext) -> str:
+        if peer_cc_info:
+            self.participant_cc_info[token_owner] = peer_cc_info
+            self.logger.info(f"Added CC client: {token_owner} tokens: {peer_cc_info}")
+
+        if not self.verify_time or time.time() - self.verify_time > self.verify_frequency:
+            self._verify_running_jobs(fl_ctx)
+
+    def _verify_running_jobs(self, fl_ctx):
+        engine = fl_ctx.get_engine()
+        run_processes = engine.run_processes
+        running_jobs = list(run_processes.keys())
+        with self.lock:
+            for job_id in running_jobs:
+                job_participants = run_processes[job_id].get(RunProcessKey.PARTICIPANTS)
+                participants = []
+                for _, client in job_participants.items():
+                    participants.append(client.name)
+
+                err, participant_tokens = self._verify_participants(participants)
+                if err:
+                    if self.critical_level == SHUTDOWN_JOB:
+                        # optionally the whole system could be shut down here; the action is left to the user
+                        engine.job_runner.stop_run(job_id, fl_ctx)
+                        self.logger.info(f"Stopped job: {job_id} due to CC verification error: {err}")
+                    else:
+                        threading.Thread(target=self._shutdown_system, args=[err, fl_ctx]).start()
+
+        self.verify_time = time.time()
+
+    def _remove_client_token(self, fl_ctx: FLContext):
+        # server side
+        peer_ctx = fl_ctx.get_peer_context()
+        token_owner = peer_ctx.get_identity_name()
+        if token_owner in self.participant_cc_info.keys():
+            self.participant_cc_info.pop(token_owner)
+            self.logger.info(f"Removed CC client: {token_owner}")
+
+    def _generate_tokens(self, fl_ctx: FLContext) -> str:
         # both server and client sides
         self.site_name = fl_ctx.get_identity_name()
-        self.helper = CCHelper(site_name=self.site_name, verifiers=self.verifiers)
-        ok = self.helper.prepare()
-        if not ok:
-            return "failed to attest"
-        self.my_token = self.helper.get_token()
-        self.participant_cc_info[self.site_name] = {CC_TOKEN: self.my_token, CC_TOKEN_VALIDATED: True}
+        workspace_folder = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT).get_site_config_dir()
+
+        self.participant_cc_info[self.site_name] = []
+        for issuer, expiration in self.cc_issuers.items():
+            try:
+                my_token = issuer.generate()
+                namespace = issuer.get_namespace()
+
+                if not isinstance(expiration, int):
+                    raise ValueError(f"token_expiration value must be int, but got {expiration.__class__}")
+                if not my_token:
+                    return f"{issuer} failed to get CC token"
+
+                self.logger.info(f"site: {self.site_name} namespace: {namespace} got the token: {my_token}")
+                cc_info = {
+                    CC_TOKEN: my_token,
+                    CC_ISSUER: issuer,
+                    CC_NAMESPACE: namespace,
+                    TOKEN_GENERATION_TIME: time.time(),
+                    TOKEN_EXPIRATION: int(expiration),
+                    CC_TOKEN_VALIDATED: True,
+                }
+                self.participant_cc_info[self.site_name].append(cc_info)
+                self.token_submitted = False
+            except CCTokenGenerateError:
+                raise RuntimeError(f"{issuer} failed to generate CC token.")
+
         return ""
 
     def _client_to_check_participant_token(self, fl_ctx: FLContext) -> str:
@@ -192,47 +279,107 @@ def _server_to_check_client_token(self, fl_ctx: FLContext) -> str:
         if not isinstance(participants, list):
             return f"bad value for {FLContextKey.JOB_PARTICIPANTS} in fl_ctx: expect list but got {type(participants)}"
 
-        participant_tokens = {self.site_name: self.my_token}
+        err, participant_tokens = self._verify_participants(participants)
+        if err:
+            return err
+
+        fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=False, private=False)
+        self.logger.info(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}")
+        return ""
+
+    def _verify_participants(self, participants):
+        # if the server's own token has expired, generate a new one
+        self._handle_expired_tokens()
+
+        participant_tokens = {}
+        site_cc_info = self.participant_cc_info[self.site_name]
+        participant_tokens[self.site_name] = self._get_participant_tokens(site_cc_info)
+
         for p in participants:
             assert isinstance(p, str)
             if p == self.site_name:
                 continue
-            if p not in self.participant_cc_info:
-                return f"no token available for participant {p}"
-            participant_tokens[p] = self.participant_cc_info[p][CC_TOKEN]
+            if self.participant_cc_info.get(p):
+                participant_tokens[p] = self._get_participant_tokens(self.participant_cc_info[p])
+            else:
+                participant_tokens[p] = [{CC_TOKEN: "", CC_NAMESPACE: ""}]
 
+        return self._validate_participants_tokens(participant_tokens), participant_tokens
 
-        err = self._validate_participants_tokens(participant_tokens)
-        if err:
-            return err
+    def _get_participant_tokens(self, site_cc_info):
+        cc_info = []
+        for i in site_cc_info:
+            namespace = i.get(CC_NAMESPACE)
+            token = i.get(CC_TOKEN)
+            cc_info.append({CC_TOKEN: token, CC_NAMESPACE: namespace, CC_TOKEN_VALIDATED: False})
+        return cc_info
 
-        for p in participant_tokens:
-            self.participant_cc_info[p][CC_TOKEN_VALIDATED] = True
-        fl_ctx.set_prop(key=PEER_CTX_CC_TOKEN, value=participant_tokens, sticky=True, private=False)
-        self.logger.debug(f"{self.site_name=} set PEER_CTX_CC_TOKEN with {participant_tokens=}")
-        return ""
+    def _handle_expired_tokens(self):
+        site_cc_info = self.participant_cc_info[self.site_name]
+        for i in site_cc_info:
+            issuer = i.get(CC_ISSUER)
+            token_generate_time = i.get(TOKEN_GENERATION_TIME)
+            expiration = i.get(TOKEN_EXPIRATION)
+            if time.time() - token_generate_time > expiration:
+                token = issuer.generate()
+                i[CC_TOKEN] = token
+                i[TOKEN_GENERATION_TIME] = time.time()
+                self.logger.info(
+                    f"site: {self.site_name} namespace: {issuer.get_namespace()} got a new CC token: {token}"
+                )
+
+        self.token_submitted = False
 
     def _validate_participants_tokens(self, participants) -> str:
         self.logger.debug(f"Validating participant tokens {participants=}")
-        result = self.helper.validate_participants(participants)
-        assert isinstance(result, dict)
-        for p in result:
-            self.participant_cc_info[p] = {CC_TOKEN: participants[p], CC_TOKEN_VALIDATED: True}
-        invalid_participant_list = [k for k, v in self.participant_cc_info.items() if v[CC_TOKEN_VALIDATED] is False]
+        result, invalid_participant_list = self._validate_participants(participants)
         if invalid_participant_list:
             invalid_participant_string = ",".join(invalid_participant_list)
             self.logger.debug(f"{invalid_participant_list=}")
-            return f"Participant {invalid_participant_string} not meeting CC requirements"
+            return f"Participant {invalid_participant_string} " + CC_VERIFICATION_FAILED
         else:
             return ""
 
-    def _not_authorize_job(self, reason: str, fl_ctx: FLContext):
-        job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "")
-        self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}")
-        fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_REASON, value=reason)
-        fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False)
+    def _validate_participants(self, participants: Dict[str, List[Dict[str, str]]]) -> (Dict[str, bool], List[str]):
+        result = {}
+        invalid_participant_list = []
+        if not participants:
+            return result, invalid_participant_list
+        for k, cc_info in participants.items():
+            for v in cc_info:
+                token = v.get(CC_TOKEN, "")
+                namespace = v.get(CC_NAMESPACE, "")
+                verifier = self.cc_verifiers.get(namespace, None)
+                try:
+                    if verifier and verifier.verify(token):
+                        result[k + "." 
+ namespace] = True + else: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") + except CCTokenVerifyError: + invalid_participant_list.append(k + " namespace: {" + namespace + "}") + self.logger.info(f"CC - results from validating participants' tokens: {result}") + return result, invalid_participant_list def _block_job(self, reason: str, fl_ctx: FLContext): job_id = fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID, "") self.log_error(fl_ctx, f"Job {job_id} is blocked: {reason}") - fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=reason) - fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False) + fl_ctx.set_prop(key=FLContextKey.JOB_BLOCK_REASON, value=CC_VERIFY_ERROR + reason, sticky=False) + fl_ctx.set_prop(key=FLContextKey.AUTHORIZATION_RESULT, value=False, sticky=False) + + def _shutdown_system(self, reason: str, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + run_processes = engine.run_processes + running_jobs = list(run_processes.keys()) + for job_id in running_jobs: + engine.job_runner.stop_run(job_id, fl_ctx) + + conn = Connection({}, engine.server.admin_server) + conn.app_ctx = engine + + cmd = TrainingCommandModule() + args = ["shutdown", "all"] + cmd.validate_command_targets(conn, args[1:]) + cmd.shutdown(conn, args) + + self.logger.error(f"CC system shutdown! due to reason: {reason}") diff --git a/nvflare/app_opt/confidential_computing/gpu_authorizer.py b/nvflare/app_opt/confidential_computing/gpu_authorizer.py new file mode 100644 index 0000000000..bd55e2a463 --- /dev/null +++ b/nvflare/app_opt/confidential_computing/gpu_authorizer.py @@ -0,0 +1,93 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer + +GPU_NAMESPACE = "x-nv-gpu-" + + +class GPUAuthorizer(CCAuthorizer): + """Note: This is just a fake implementation for GPU authorizer. It will be replaced later + with the real implementation. + + """ + + def __init__(self, verifiers: list) -> None: + """ + + Args: + verifiers (list): + each element in this list is a dictionary and the keys of dictionary are + "devices", "env", "url", "appraisal_policy_file" and "result_policy_file." + + the values of devices are "gpu" and "cpu" + the values of env are "local" and "test" + currently, valid combination is gpu + local + + url must be an empty string + appraisal_policy_file must point to an existing file + currently supports an empty file only + + result_policy_file must point to an existing file + currently supports the following content only + + .. 
code-block:: json + + { + "version":"1.0", + "authorization-rules":{ + "x-nv-gpu-available":true, + "x-nv-gpu-attestation-report-available":true, + "x-nv-gpu-info-fetched":true, + "x-nv-gpu-arch-check":true, + "x-nv-gpu-root-cert-available":true, + "x-nv-gpu-cert-chain-verified":true, + "x-nv-gpu-ocsp-cert-chain-verified":true, + "x-nv-gpu-ocsp-signature-verified":true, + "x-nv-gpu-cert-ocsp-nonce-match":true, + "x-nv-gpu-cert-check-complete":true, + "x-nv-gpu-measurement-available":true, + "x-nv-gpu-attestation-report-parsed":true, + "x-nv-gpu-nonce-match":true, + "x-nv-gpu-attestation-report-driver-version-match":true, + "x-nv-gpu-attestation-report-vbios-version-match":true, + "x-nv-gpu-attestation-report-verified":true, + "x-nv-gpu-driver-rim-schema-fetched":true, + "x-nv-gpu-driver-rim-schema-validated":true, + "x-nv-gpu-driver-rim-cert-extracted":true, + "x-nv-gpu-driver-rim-signature-verified":true, + "x-nv-gpu-driver-rim-driver-measurements-available":true, + "x-nv-gpu-driver-vbios-rim-fetched":true, + "x-nv-gpu-vbios-rim-schema-validated":true, + "x-nv-gpu-vbios-rim-cert-extracted":true, + "x-nv-gpu-vbios-rim-signature-verified":true, + "x-nv-gpu-vbios-rim-driver-measurements-available":true, + "x-nv-gpu-vbios-index-conflict":true, + "x-nv-gpu-measurements-match":true + } + } + + """ + super().__init__() + self.verifiers = verifiers + + def get_namespace(self) -> str: + return GPU_NAMESPACE + + def generate(self) -> str: + raise NotImplementedError + + def verify(self, token: str) -> bool: + raise NotImplementedError diff --git a/nvflare/app_opt/confidential_computing/tdx_authorizer.py b/nvflare/app_opt/confidential_computing/tdx_authorizer.py new file mode 100644 index 0000000000..21bff9035e --- /dev/null +++ b/nvflare/app_opt/confidential_computing/tdx_authorizer.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import subprocess
+
+from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer
+
+TDX_NAMESPACE = "tdx_"
+TDX_CLI_CONFIG = "config.json"
+TOKEN_FILE = "token.txt"
+VERIFY_FILE = "verify.txt"
+ERROR_FILE = "error.txt"
+
+
+class TDXAuthorizer(CCAuthorizer):
+    def __init__(self, tdx_cli_command: str, config_dir: str) -> None:
+        super().__init__()
+        self.tdx_cli_command = tdx_cli_command
+        self.config_dir = config_dir
+
+        self.config_file = os.path.join(self.config_dir, TDX_CLI_CONFIG)
+
+    def generate(self) -> str:
+        token_file = os.path.join(self.config_dir, TOKEN_FILE)
+        error_file = os.path.join(self.config_dir, ERROR_FILE)
+
+        command = ["sudo", self.tdx_cli_command, "-c", self.config_file, "token", "--no-eventlog"]
+        with open(token_file, "w") as out, open(error_file, "w") as err_out:
+            subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out)
+
+        if not os.path.exists(error_file) or not os.path.exists(token_file):
+            return ""
+
+        try:
+            with open(error_file, "r") as e_f:
+                if "Error:" in e_f.read():
+                    return ""
+                else:
+                    with open(token_file, "r") as t_f:
+                        token = t_f.readline()
+                    return token
+        except Exception:
+            return ""
+
+    def verify(self, token: str) -> bool:
+        verify_file = os.path.join(self.config_dir, VERIFY_FILE)
+        error_file = os.path.join(self.config_dir, ERROR_FILE)
+
+        command = [self.tdx_cli_command, "verify", "--config", self.config_file, "--token", token]
+        with open(verify_file, "w") as out, open(error_file, "w") as err_out:
+            subprocess.run(command, preexec_fn=os.setsid, stdout=out, stderr=err_out)
+
+        if not os.path.exists(error_file):
+            return False
+
+        try:
+            with open(error_file, "r") as f:
+                if "Error:" in f.read():
+                    return False
+        except Exception:
+            return False
+
+        return True
+
+    def get_namespace(self) -> str:
+        return TDX_NAMESPACE
diff --git a/nvflare/private/fed/app/client/client_train.py b/nvflare/private/fed/app/client/client_train.py
index f933c14349..2618974a81 100644
--- a/nvflare/private/fed/app/client/client_train.py
+++ b/nvflare/private/fed/app/client/client_train.py
@@ -27,7 +27,7 @@
 from nvflare.fuel.utils.argument_utils import parse_vars
 from nvflare.private.defs import AppFolderConstants
 from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager
-from nvflare.private.fed.app.utils import version_check
+from nvflare.private.fed.app.utils import component_security_check, version_check
 from nvflare.private.fed.client.admin import FedAdminAgent
 from nvflare.private.fed.client.client_engine import ClientEngine
 from nvflare.private.fed.client.client_status import ClientStatus
@@ -108,8 +108,11 @@ def main(args):
             time.sleep(1.0)
 
         with client_engine.new_context() as fl_ctx:
+            fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True)
             client_engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx)
 
+            component_security_check(fl_ctx)
+
             client_engine.fire_event(EventType.BEFORE_CLIENT_REGISTER, fl_ctx)
             federated_client.register(fl_ctx)
 
             fl_ctx.set_prop(FLContextKey.CLIENT_TOKEN, federated_client.token)
diff --git a/nvflare/private/fed/app/deployer/server_deployer.py b/nvflare/private/fed/app/deployer/server_deployer.py
index e800820497..ecb2a1b04f 100644
--- a/nvflare/private/fed/app/deployer/server_deployer.py
+++ b/nvflare/private/fed/app/deployer/server_deployer.py
@@ -16,8 +16,9 @@
 import threading
 
 from nvflare.apis.event_type import EventType
-from nvflare.apis.fl_constant import SystemComponents
+from nvflare.apis.fl_constant import FLContextKey, SystemComponents
 from nvflare.apis.workspace import Workspace
+from 
nvflare.private.fed.app.utils import component_security_check from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.job_runner import JobRunner from nvflare.private.fed.server.run_manager import RunManager @@ -119,8 +120,11 @@ def deploy(self, args): run_manager.add_component(SystemComponents.JOB_RUNNER, job_runner) with services.engine.new_context() as fl_ctx: + fl_ctx.set_prop(FLContextKey.WORKSPACE_OBJECT, workspace, private=True) services.engine.fire_event(EventType.SYSTEM_BOOTSTRAP, fl_ctx) + component_security_check(fl_ctx) + threading.Thread(target=self._start_job_runner, args=[job_runner, fl_ctx]).start() services.status = ServerStatus.STARTED diff --git a/nvflare/private/fed/app/utils.py b/nvflare/private/fed/app/utils.py index 8552ec4945..94712d4b21 100644 --- a/nvflare/private/fed/app/utils.py +++ b/nvflare/private/fed/app/utils.py @@ -20,6 +20,9 @@ import psutil +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import UnsafeComponentError from nvflare.fuel.hci.security import hash_password from nvflare.private.defs import SSLConstants from nvflare.private.fed.runner import Runner @@ -98,3 +101,12 @@ def version_check(): raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.") if sys.version_info < (3, 8): raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10") + + +def component_security_check(fl_ctx: FLContext): + exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) + if exceptions: + for _, exception in exceptions.items(): + if isinstance(exception, UnsafeComponentError): + print(f"Unsafe component configured, could not start {fl_ctx.get_identity_name()}!!") + raise RuntimeError(exception) diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index d3475f0253..e95e2f5269 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -17,6 +17,7 @@ import time from typing import List, Optional +from nvflare.apis.event_type import EventType from nvflare.apis.filter import Filter from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_constant import ReturnCode as ShareableRC @@ -292,6 +293,9 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): server's reply to the last message """ + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable = Shareable() + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { @@ -299,7 +303,8 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.SSID: ssid, CellMessageHeaderKeys.PROJECT_NAME: task_name, - } + }, + shareable, ) try: result = self.cell.send_request( @@ -328,6 +333,11 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C heartbeats_log_interval = 10 while not self.heartbeat_done: try: + engine.fire_event(EventType.BEFORE_CLIENT_HEARTBEAT, fl_ctx) + shareable = Shareable() + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + job_ids = engine.get_all_job_ids() heartbeat_message = new_cell_message( { @@ -336,7 +346,8 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C 
CellMessageHeaderKeys.CLIENT_NAME: client_name, CellMessageHeaderKeys.PROJECT_NAME: task_name, CellMessageHeaderKeys.JOB_IDS: job_ids, - } + }, + shareable, ) try: @@ -367,6 +378,7 @@ def send_heartbeat(self, servers, task_name, token, ssid, client_name, engine: C except Exception as ex: raise FLCommunicationError("error:client_quit", ex) + engine.fire_event(EventType.AFTER_CLIENT_HEARTBEAT, fl_ctx) for i in range(wait_times): time.sleep(2) if self.heartbeat_done: diff --git a/nvflare/private/fed/client/scheduler_cmds.py b/nvflare/private/fed/client/scheduler_cmds.py index 58828121bb..7ea2e8d55b 100644 --- a/nvflare/private/fed/client/scheduler_cmds.py +++ b/nvflare/private/fed/client/scheduler_cmds.py @@ -16,7 +16,7 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLContextKey, ReturnCode, SystemComponents +from nvflare.apis.fl_constant import FLContextKey, ReturnCode, ServerCommandKey, SystemComponents from nvflare.apis.resource_manager_spec import ResourceConsumerSpec, ResourceManagerSpec from nvflare.apis.shareable import Shareable from nvflare.private.admin_defs import Message @@ -68,6 +68,8 @@ def process(self, req: Message, app_ctx) -> Message: fl_ctx.set_prop(key=FLContextKey.CLIENT_RESOURCE_SPECS, value=resource_spec, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.CURRENT_JOB_ID, job_id, private=True, sticky=False) + shared_fl_ctx = req.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) engine.fire_event(EventType.BEFORE_CHECK_RESOURCE_MANAGER, fl_ctx) block_reason = fl_ctx.get_prop(FLContextKey.JOB_BLOCK_REASON) @@ -78,7 +80,7 @@ def process(self, req: Message, app_ctx) -> Message: is_resource_enough, token = resource_manager.check_resources( resource_requirement=resource_spec, fl_ctx=fl_ctx ) - except Exception: + except Exception as e: result.set_return_code(ReturnCode.EXECUTION_EXCEPTION) result.set_header(ShareableHeader.IS_RESOURCE_ENOUGH, is_resource_enough) diff --git a/nvflare/private/fed/server/admin.py b/nvflare/private/fed/server/admin.py index eec732013a..71fc765939 100644 --- a/nvflare/private/fed/server/admin.py +++ b/nvflare/private/fed/server/admin.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import threading import time from typing import List, Optional from nvflare.apis.event_type import EventType -from nvflare.apis.shareable import ReservedHeaderKey +from nvflare.apis.fl_constant import ServerCommandKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.net_manager import NetManager @@ -229,11 +230,12 @@ def send_request_to_client(self, req: Message, client_token: str, timeout_secs=2 if not isinstance(req, Message): raise TypeError("request must be Message but got {}".format(type(req))) reqs = {client_token: req} - replies = self.send_requests(reqs, timeout_secs=timeout_secs) - if replies is None or len(replies) <= 0: - return None - else: - return replies[0] + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(reqs, fl_ctx, timeout_secs=timeout_secs) + if replies is None or len(replies) <= 0: + return None + else: + return replies[0] def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> dict: """Send requests to clients @@ -250,12 +252,13 @@ def send_requests_and_get_reply_dict(self, requests: dict, timeout_secs=2.0) -> for token, _ in requests.items(): result[token] = None - replies = self.send_requests(requests, timeout_secs=timeout_secs) - for r in replies: - result[r.client_token] = r.reply + with self.sai.new_context() as fl_ctx: + replies = self.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) + for r in replies: + result[r.client_token] = r.reply return result - def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [ClientReply]: + def send_requests(self, requests: dict, fl_ctx: FLContext, timeout_secs=2.0, optional=False) -> [ClientReply]: """Send requests to clients. 
NOTE:: @@ -266,6 +269,7 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl Args: requests: A dict of requests: {client token: request or list of requests} + fl_ctx: FLContext timeout_secs: how long to wait for reply before timeout optional: whether the requests are optional @@ -274,9 +278,10 @@ def send_requests(self, requests: dict, timeout_secs=2.0, optional=False) -> [Cl """ for _, request in requests.items(): - with self.sai.new_context() as fl_ctx: - self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) - request.set_header(ReservedHeaderKey.PEER_PROPS, copy.deepcopy(fl_ctx.get_all_public_props())) + # with self.sai.new_context() as fl_ctx: + self.sai.fire_event(EventType.BEFORE_SEND_ADMIN_COMMAND, fl_ctx) + shared_fl_ctx = gen_new_peer_ctx(fl_ctx) + request.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) return send_requests( cell=self.cell, diff --git a/nvflare/private/fed/server/cmd_utils.py b/nvflare/private/fed/server/cmd_utils.py index d2d7bd6bfd..4cc8cc72ce 100644 --- a/nvflare/private/fed/server/cmd_utils.py +++ b/nvflare/private/fed/server/cmd_utils.py @@ -148,7 +148,8 @@ def send_request_to_clients(self, conn, message): cmd_timeout = conn.get_prop(ConnProps.CMD_TIMEOUT) if not cmd_timeout: cmd_timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=cmd_timeout) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=cmd_timeout) return replies diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index d07dc99246..ae50e4c3a6 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -493,7 +493,7 @@ def register_client(self, request: Message) -> Message: shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) fl_ctx.set_peer_context(shared_fl_ctx) - self.engine.fire_event(EventType.CLIENT_REGISTERED, fl_ctx=fl_ctx) + self.engine.fire_event(EventType.CLIENT_REGISTER_RECEIVED, fl_ctx=fl_ctx) exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) if exceptions: @@ -513,6 +513,7 @@ def register_client(self, request: Message) -> Message: } else: headers = {} + self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx) return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}") @@ -539,6 +540,11 @@ def quit_client(self, request: Message) -> Message: token = client.get_token() self.logout_client(token) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) + self.engine.fire_event(EventType.CLIENT_QUIT, fl_ctx=fl_ctx) + headers = {CellMessageHeaderKeys.MESSAGE: "Removed client"} return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) @@ -572,6 +578,11 @@ def client_heartbeat(self, request: Message) -> Message: if error is not None: return make_cellnet_reply(rc=F3ReturnCode.COMM_ERROR, error=error) + data = request.payload + shared_fl_ctx = data.get_header(ServerCommandKey.PEER_FL_CONTEXT) + fl_ctx.set_peer_context(shared_fl_ctx) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_RECEIVED, fl_ctx=fl_ctx) + token = request.get_header(CellMessageHeaderKeys.TOKEN) client_name = request.get_header(CellMessageHeaderKeys.CLIENT_NAME) @@ -593,6 +604,7 @@ def client_heartbeat(self, request: 
Message) -> Message: f"These jobs: {display_runs} are not running on the server. " f"Ask client: {client_name} to abort these runs." ) + self.engine.fire_event(EventType.CLIENT_HEARTBEAT_PROCESSED, fl_ctx=fl_ctx) return reply def _sync_client_jobs(self, request, client_token): diff --git a/nvflare/private/fed/server/job_runner.py b/nvflare/private/fed/server/job_runner.py index ae2cb7d568..15a9bea833 100644 --- a/nvflare/private/fed/server/job_runner.py +++ b/nvflare/private/fed/server/job_runner.py @@ -49,7 +49,8 @@ def _send_to_clients(admin_server, client_sites: List[str], engine, message, tim if timeout is None: timeout = admin_server.timeout - replies = admin_server.send_requests(requests, timeout_secs=timeout, optional=optional) + with admin_server.sai.new_context() as fl_ctx: + replies = admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout, optional=optional) return replies @@ -248,7 +249,7 @@ def _start_run(self, job_id: str, job: Job, client_sites: Dict[str, DispatchInfo if err: raise RuntimeError(f"Could not start the server App for job: {job_id}.") - replies = engine.start_client_job(job_id, client_sites) + replies = engine.start_client_job(job_id, client_sites, fl_ctx) client_sites_names = list(client_sites.keys()) check_client_replies(replies=replies, client_sites=client_sites_names, command=f"start job ({job_id})") display_sites = ",".join(client_sites_names) diff --git a/nvflare/private/fed/server/server_engine.py b/nvflare/private/fed/server/server_engine.py index d6b636b575..ed52b4e95f 100644 --- a/nvflare/private/fed/server/server_engine.py +++ b/nvflare/private/fed/server/server_engine.py @@ -721,8 +721,8 @@ def reset_errors(self, job_id) -> str: return f"reset the server error stats for job: {job_id}" - def _send_admin_requests(self, requests, timeout_secs=10) -> List[ClientReply]: - return self.server.admin_server.send_requests(requests, timeout_secs=timeout_secs) + def _send_admin_requests(self, requests, fl_ctx: FLContext, timeout_secs=10) -> List[ClientReply]: + return self.server.admin_server.send_requests(requests, fl_ctx, timeout_secs=timeout_secs) def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> Dict[str, Tuple[bool, str]]: requests = {} @@ -737,7 +737,7 @@ def check_client_resources(self, job: Job, resource_reqs, fl_ctx: FLContext) -> requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, 15) + replies = self._send_admin_requests(requests, fl_ctx, 15) result = {} for r in replies: site_name = r.client_name @@ -765,7 +765,7 @@ def _make_message_for_check_resource(self, job, resource_requirements, fl_ctx): return request def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): requests = {} for site_name, result in resource_check_results.items(): @@ -778,9 +778,9 @@ def cancel_client_resources( if client: requests.update({client.token: request}) if requests: - _ = self._send_admin_requests(requests) + _ = self._send_admin_requests(requests, fl_ctx) - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): requests = {} for site, dispatch_info in client_sites.items(): resource_requirement = dispatch_info.resource_requirements @@ -793,7 +793,7 @@ def start_client_job(self, job_id, client_sites): 
requests.update({client.token: request}) replies = [] if requests: - replies = self._send_admin_requests(requests, timeout_secs=20) + replies = self._send_admin_requests(requests, fl_ctx, timeout_secs=20) return replies def stop_all_jobs(self): diff --git a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py index df6c63a912..1f7f9ab48e 100644 --- a/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py +++ b/tests/unit_test/app_common/job_schedulers/job_scheduler_test.py @@ -120,7 +120,7 @@ def persist_components(self, fl_ctx: FLContext, completed: bool): def restore_components(self, snapshot, fl_ctx: FLContext): pass - def start_client_job(self, job_id, client_sites): + def start_client_job(self, job_id, client_sites, fl_ctx: FLContext): pass def check_client_resources( @@ -136,15 +136,15 @@ def get_client_name_from_token(self, token): return self.clients.get(token) def cancel_client_resources( - self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict] + self, resource_check_results: Dict[str, Tuple[bool, str]], resource_reqs: Dict[str, dict], fl_ctx: FLContext ): - with self.new_context() as fl_ctx: - for site_name, result in resource_check_results.items(): - check_result, token = result - if check_result and token: - self.clients[site_name].resource_manager.cancel_resources( - resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx - ) + # with self.new_context() as fl_ctx: + for site_name, result in resource_check_results.items(): + check_result, token = result + if check_result and token: + self.clients[site_name].resource_manager.cancel_resources( + resource_requirement=resource_reqs[site_name], token=token, fl_ctx=fl_ctx + ) def update_job_run_status(self): pass diff --git a/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py new file mode 100644 index 0000000000..3db086a277 --- /dev/null +++ b/tests/unit_test/app_opt/confidential_computing/cc_manager_test.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import Mock
+
+from nvflare.apis.fl_constant import ReservedKey
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.server_engine_spec import ServerEngineSpec
+from nvflare.app_opt.confidential_computing.cc_manager import (
+    CC_INFO,
+    CC_NAMESPACE,
+    CC_TOKEN,
+    CC_TOKEN_VALIDATED,
+    CC_VERIFICATION_FAILED,
+    CCManager,
+)
+from nvflare.app_opt.confidential_computing.tdx_authorizer import TDX_NAMESPACE, TDXAuthorizer
+
+VALID_TOKEN = "valid_token"
+INVALID_TOKEN = "invalid_token"
+
+
+class TestCCManager:
+    def setup_method(self, method):
+        issuers_conf = [{"issuer_id": "tdx_authorizer", "token_expiration": 250}]
+
+        verify_ids = ["tdx_authorizer"]
+        self.cc_manager = CCManager(issuers_conf, verify_ids)
+
+    def test_authorizer_setup(self):
+
+        fl_ctx, tdx_authorizer = self._setup_authorizers()
+
+        assert self.cc_manager.cc_issuers == {tdx_authorizer: 250}
+        assert self.cc_manager.cc_verifiers == {TDX_NAMESPACE: tdx_authorizer}
+
+    def _setup_authorizers(self):
+        fl_ctx = Mock(spec=FLContext)
+        fl_ctx.get_identity_name.return_value = "server"
+        engine = Mock(spec=ServerEngineSpec)
+        fl_ctx.get_engine.return_value = engine
+
+        tdx_authorizer = Mock(spec=TDXAuthorizer)
+        tdx_authorizer.get_namespace.return_value = TDX_NAMESPACE
+        tdx_authorizer.verify = self._verify_token
+        engine.get_component.return_value = tdx_authorizer
+        self.cc_manager._setup_cc_authorizers(fl_ctx)
+
+        tdx_authorizer.generate.return_value = VALID_TOKEN
+        self.cc_manager._generate_tokens(fl_ctx)
+
+        return fl_ctx, tdx_authorizer
+
+    def _verify_token(self, token):
+        return token == VALID_TOKEN
+
+    def test_add_client_token(self):
+
+        cc_info1, cc_info2 = self._add_failed_tokens()
+
+        assert self.cc_manager.participant_cc_info["client1"] == cc_info1
+        assert self.cc_manager.participant_cc_info["client2"] == cc_info2
+
+    def _add_failed_tokens(self):
+        self.cc_manager._verify_running_jobs = Mock()
+        cc_info1, fl_ctx = self._add_client_token("client1", VALID_TOKEN)
+        self.cc_manager._add_client_token(fl_ctx)
+
+        cc_info2, fl_ctx = self._add_client_token("client2", INVALID_TOKEN)
+        self.cc_manager._add_client_token(fl_ctx)
+        return cc_info1, cc_info2
+
+    def test_verification_success(self):
+
+        self._setup_authorizers()
+
+        self.cc_manager._verify_running_jobs = Mock()
+        cc_info1, fl_ctx = self._add_client_token("client1", VALID_TOKEN)
+        self.cc_manager._add_client_token(fl_ctx)
+
+        cc_info2, fl_ctx = self._add_client_token("client2", VALID_TOKEN)
+        self.cc_manager._add_client_token(fl_ctx)
+
+        self.cc_manager._handle_expired_tokens = Mock()
+
+        err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"])
+
+        assert not err
+
+    def test_verification_failed(self):
+
+        self._setup_authorizers()
+
+        self.cc_manager._verify_running_jobs = Mock()
+        self._add_failed_tokens()
+        self.cc_manager._handle_expired_tokens = Mock()
+
+        err, participant_tokens = self.cc_manager._verify_participants(["client1", "client2"])
+
+        assert "client2" in err
+        assert CC_VERIFICATION_FAILED in err
+
+    def _add_client_token(self, client_name, token):
+        peer_ctx = FLContext()
+        cc_info = [{CC_TOKEN: token, CC_NAMESPACE: TDX_NAMESPACE, CC_TOKEN_VALIDATED: False}]
+        
peer_ctx.set_prop(CC_INFO, cc_info) + peer_ctx.set_prop(ReservedKey.IDENTITY_NAME, client_name) + fl_ctx = Mock(spec=FLContext) + fl_ctx.get_peer_context.return_value = peer_ctx + return cc_info, fl_ctx diff --git a/tests/unit_test/private/fed/server/fed_server_test.py b/tests/unit_test/private/fed/server/fed_server_test.py index abf8a21974..235cfac0c9 100644 --- a/tests/unit_test/private/fed/server/fed_server_test.py +++ b/tests/unit_test/private/fed/server/fed_server_test.py @@ -16,6 +16,7 @@ import pytest +from nvflare.apis.shareable import Shareable from nvflare.private.defs import CellMessageHeaderKeys, new_cell_message from nvflare.private.fed.server.fed_server import FederatedServer from nvflare.private.fed.server.server_state import ColdState, HotState @@ -46,7 +47,8 @@ def test_heart_beat_abort_jobs(self, server_state, expected): CellMessageHeaderKeys.CLIENT_NAME: "client_name", CellMessageHeaderKeys.PROJECT_NAME: "task_name", CellMessageHeaderKeys.JOB_IDS: ["extra_job"], - } + }, + Shareable(), ) result = server.client_heartbeat(request)
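
Reviewer note: to make the new wiring concrete, below is a minimal sketch of how the pieces introduced in this patch fit together: one CCManager configured with a single TDX issuer/verifier, matching the constructor arguments (cc_issuers_conf, cc_verifier_ids, verify_frequency, critical_level) and the TDXAuthorizer signature added above. The component ID "tdx_authorizer", the CLI path, and the config directory are hypothetical placeholders; in a real deployment these objects would be declared as site components in the local resources configuration so that engine.get_component() can resolve the issuer and verifier IDs during SYSTEM_BOOTSTRAP.

    # Minimal sketch (not part of this PR); IDs and paths are hypothetical.
    from nvflare.app_opt.confidential_computing.cc_manager import CCManager, SHUTDOWN_JOB
    from nvflare.app_opt.confidential_computing.tdx_authorizer import TDXAuthorizer

    # The issuer/verifier that shells out to the TDX CLI for attestation tokens.
    tdx_authorizer = TDXAuthorizer(
        tdx_cli_command="/usr/local/bin/trustauthority-cli",  # hypothetical CLI path
        config_dir="/opt/nvflare/site-1/local",  # hypothetical config dir
    )

    # One issuers_conf entry per token source: the component ID of a CCAuthorizer
    # plus the token lifetime in seconds before CCManager regenerates it.
    cc_manager = CCManager(
        cc_issuers_conf=[{"issuer_id": "tdx_authorizer", "token_expiration": 250}],
        cc_verifier_ids=["tdx_authorizer"],  # resolved via engine.get_component()
        verify_frequency=600,  # re-verify running jobs at most every 600 seconds
        critical_level=SHUTDOWN_JOB,  # on failure, block/stop only the offending job
    )

Registered this way on both server and clients, CCManager generates its tokens at SYSTEM_BOOTSTRAP (with component_security_check aborting startup on UnsafeComponentError), attaches them to registration and heartbeat messages, and on the server re-verifies the tokens of all running jobs' participants at most once per verify_frequency.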