diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index f85cac0d09..e95c23242c 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -52,7 +52,7 @@ from nvflare.fuel.utils.zip_utils import split_path, unzip_all_from_bytes, zip_directory_to_bytes from nvflare.private.defs import AppFolderConstants from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer -from nvflare.private.fed.app.utils import kill_child_processes +from nvflare.private.fed.app.utils import init_security_content_service, kill_child_processes from nvflare.private.fed.client.client_status import ClientStatus from nvflare.private.fed.server.job_meta_validator import JobMetaValidator from nvflare.private.fed.simulator.simulator_app_runner import SimulatorServerAppRunner @@ -153,6 +153,8 @@ def setup(self): AuthorizationService.initialize(EmptyAuthorizer()) AuditService.the_auditor = SimulatorAuditor() + init_security_content_service(self.args.workspace) + self.simulator_root = os.path.join(self.args.workspace, SimulatorConstants.JOB_NAME) if os.path.exists(self.simulator_root): shutil.rmtree(self.simulator_root) diff --git a/nvflare/private/fed/app/simulator/simulator_worker.py b/nvflare/private/fed/app/simulator/simulator_worker.py index 7fd94f441b..1fd9ba9614 100644 --- a/nvflare/private/fed/app/simulator/simulator_worker.py +++ b/nvflare/private/fed/app/simulator/simulator_worker.py @@ -30,7 +30,7 @@ from nvflare.fuel.hci.server.authz import AuthorizationService from nvflare.fuel.sec.audit import AuditService from nvflare.private.fed.app.deployer.base_client_deployer import BaseClientDeployer -from nvflare.private.fed.app.utils import check_parent_alive +from nvflare.private.fed.app.utils import check_parent_alive, init_security_content_service from nvflare.private.fed.client.client_engine import ClientEngine from nvflare.private.fed.client.client_status import ClientStatus from nvflare.private.fed.client.fed_client import FederatedClient @@ -241,6 +241,8 @@ def main(args): # AuditService.initialize(audit_file_name=WorkspaceConstants.AUDIT_LOG) AuditService.the_auditor = SimulatorAuditor() + init_security_content_service(args.workspace) + if args.gpu: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu diff --git a/nvflare/private/fed/app/utils.py b/nvflare/private/fed/app/utils.py index 94712d4b21..d219e3c5b0 100644 --- a/nvflare/private/fed/app/utils.py +++ b/nvflare/private/fed/app/utils.py @@ -20,10 +20,12 @@ import psutil -from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_constant import FLContextKey, WorkspaceConstants from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import UnsafeComponentError +from nvflare.apis.workspace import Workspace from nvflare.fuel.hci.security import hash_password +from nvflare.fuel.sec.security_content_service import SecurityContentService from nvflare.private.defs import SSLConstants from nvflare.private.fed.runner import Runner from nvflare.private.fed.server.admin import FedAdminServer @@ -103,6 +105,12 @@ def version_check(): raise RuntimeError("Python versions 3.7 and below are not supported. Please use Python 3.8, 3.9 or 3.10") +def init_security_content_service(workspace_dir): + os.makedirs(os.path.join(workspace_dir, WorkspaceConstants.STARTUP_FOLDER_NAME), exist_ok=True) + workspace_obj = Workspace(root_dir=workspace_dir) + SecurityContentService.initialize(content_folder=workspace_obj.get_startup_kit_dir()) + + def component_security_check(fl_ctx: FLContext): exceptions = fl_ctx.get_prop(FLContextKey.EXCEPTIONS) if exceptions: diff --git a/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py b/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py index 0f9ac29b75..0fcfefe1d3 100644 --- a/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py +++ b/tests/unit_test/private/fed/app/simulator/simulator_runner_test.py @@ -13,11 +13,13 @@ # limitations under the License. import os +import shutil import uuid from unittest.mock import patch import pytest +from nvflare.apis.fl_constant import WorkspaceConstants from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner from nvflare.private.fed.utils.fed_utils import split_gpus @@ -28,14 +30,22 @@ def get_root_url_for_child(self): class TestSimulatorRunner: + def setup_method(self, method): + self.workspace_name = str(uuid.uuid4()) + self.cwd = os.getcwd() + os.makedirs(os.path.join(self.cwd, self.workspace_name, WorkspaceConstants.STARTUP_FOLDER_NAME)) + + def teardown_method(self, method): + os.chdir(self.cwd) + shutil.rmtree(os.path.join(self.cwd, self.workspace_name)) + @patch("nvflare.private.fed.app.deployer.simulator_deployer.SimulatorServer.deploy") @patch("nvflare.private.fed.app.utils.FedAdminServer") @patch("nvflare.private.fed.client.fed_client.FederatedClient.register") @patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell()) def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register, mock_cell): - workspace_name = str(uuid.uuid4()) job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job") - runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, threads=1) + runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, threads=1) assert runner.setup() expected_clients = ["site-1", "site-2"] @@ -49,9 +59,8 @@ def test_valid_job_simulate_setup(self, mock_deploy, mock_admin, mock_register, @patch("nvflare.private.fed.client.fed_client.FederatedClient.register") @patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell()) def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_cell): - workspace_name = str(uuid.uuid4()) job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job") - runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, clients="site-1", threads=1) + runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, clients="site-1", threads=1) assert runner.setup() expected_clients = ["site-1"] @@ -65,9 +74,8 @@ def test_client_names_setup(self, mock_deploy, mock_admin, mock_register, mock_c @patch("nvflare.private.fed.client.fed_client.FederatedClient.register") @patch("nvflare.private.fed.server.fed_server.BaseServer.get_cell", return_value=MockCell()) def test_no_app_for_client(self, mock_deploy, mock_admin, mock_register, mock_cell): - workspace_name = str(uuid.uuid4()) job_folder = os.path.join(os.path.dirname(__file__), "../../../../data/jobs/valid_job") - runner = SimulatorRunner(job_folder=job_folder, workspace=workspace_name, n_clients=3, threads=1) + runner = SimulatorRunner(job_folder=job_folder, workspace=self.workspace_name, n_clients=3, threads=1) assert not runner.setup() @pytest.mark.parametrize(