diff --git a/nvflare/app_common/executors/task_script_runner.py b/nvflare/app_common/executors/task_script_runner.py index c7103b661b..fba46f6b05 100644 --- a/nvflare/app_common/executors/task_script_runner.py +++ b/nvflare/app_common/executors/task_script_runner.py @@ -61,14 +61,29 @@ def get_sys_argv(self): return [self.script_path] + args_list def get_script_full_path(self, script_path) -> str: - target_files = None + target_file = None + script_filename = os.path.basename(script_path) + script_dirs = os.path.dirname(script_path) + for r, dirs, files in os.walk(os.getcwd()): - target_files = [os.path.join(r, f) for f in files if f == script_path] - if target_files: + for f in files: + absolute_path = os.path.join(r, f) + if absolute_path.endswith(script_path): + parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep) + if os.path.isdir(parent_dir): + target_file = absolute_path + break + + if not script_dirs and f == script_filename: + target_file = absolute_path + break + + if target_file: break - if not target_files: + + if not target_file: raise ValueError(f"{script_path} is not found") - return target_files[0] + return target_file def log_print(*args, logger=TaskScriptRunner.logger, **kwargs): diff --git a/tests/unit_test/app_common/executors/task_script_runner_test.py b/tests/unit_test/app_common/executors/task_script_runner_test.py index 867d2c4b31..65057e5f01 100644 --- a/tests/unit_test/app_common/executors/task_script_runner_test.py +++ b/tests/unit_test/app_common/executors/task_script_runner_test.py @@ -17,8 +17,17 @@ from nvflare.app_common.executors.task_script_runner import TaskScriptRunner -class TestExecTaskFuncWrapper(unittest.TestCase): +class TestTaskScriptRunner(unittest.TestCase): def test_app_scripts_and_args(self): + curr_dir = os.getcwd() + script_path = "nvflare/cli.py" + script_args = "--batch_size 4" + wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args) + + self.assertTrue(wrapper.script_path.endswith(script_path)) + self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"]) + + def test_app_scripts_and_args2(self): curr_dir = os.getcwd() script_path = "cli.py" script_args = "--batch_size 4" @@ -26,3 +35,40 @@ def test_app_scripts_and_args(self): self.assertTrue(wrapper.script_path.endswith(script_path)) self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"]) + + def test_app_scripts_with_sub_dirs1(self): + curr_dir = os.getcwd() + script_path = "nvflare/__init__.py" + wrapper = TaskScriptRunner(script_path=script_path) + + self.assertTrue(wrapper.script_path.endswith(script_path)) + self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "__init__.py")]) + + def test_app_scripts_with_sub_dirs2(self): + curr_dir = os.getcwd() + script_path = "nvflare/app_common/executors/__init__.py" + wrapper = TaskScriptRunner(script_path=script_path) + + self.assertTrue(wrapper.script_path.endswith(script_path)) + self.assertEqual( + wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "app_common", "executors", "__init__.py")] + ) + + def test_app_scripts_with_sub_dirs3(self): + curr_dir = os.getcwd() + script_path = "executors/task_script_runner.py" + wrapper = TaskScriptRunner(script_path=script_path) + + self.assertTrue(wrapper.script_path.endswith(script_path)) + self.assertEqual( + wrapper.get_sys_argv(), + [os.path.join(curr_dir, "nvflare", "app_common", "executors", "task_script_runner.py")], + ) + + def test_app_scripts_with_sub_dirs4(self): + curr_dir = os.getcwd() + script_path = "in_process/api.py" + wrapper = TaskScriptRunner(script_path=script_path) + + self.assertTrue(wrapper.script_path.endswith(script_path)) + self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "client", "in_process", "api.py")])