Skip to content

Commit

Permalink
handle cases where the script with relative path in Script Runner (#2543
Browse files Browse the repository at this point in the history
)

* handle cases where the script with relative path

* handle cases where the script with relative path

* add more unit test cases and change the file search logics

* code format

* add more unit test cases and change the file search logics
  • Loading branch information
chesterxgchen authored May 2, 2024
1 parent 372624d commit 3dd4476
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
25 changes: 20 additions & 5 deletions nvflare/app_common/executors/task_script_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,29 @@ def get_sys_argv(self):
return [self.script_path] + args_list

def get_script_full_path(self, script_path) -> str:
target_files = None
target_file = None
script_filename = os.path.basename(script_path)
script_dirs = os.path.dirname(script_path)

for r, dirs, files in os.walk(os.getcwd()):
target_files = [os.path.join(r, f) for f in files if f == script_path]
if target_files:
for f in files:
absolute_path = os.path.join(r, f)
if absolute_path.endswith(script_path):
parent_dir = absolute_path[: absolute_path.find(script_path)].rstrip(os.sep)
if os.path.isdir(parent_dir):
target_file = absolute_path
break

if not script_dirs and f == script_filename:
target_file = absolute_path
break

if target_file:
break
if not target_files:

if not target_file:
raise ValueError(f"{script_path} is not found")
return target_files[0]
return target_file


def log_print(*args, logger=TaskScriptRunner.logger, **kwargs):
Expand Down
48 changes: 47 additions & 1 deletion tests/unit_test/app_common/executors/task_script_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,58 @@
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner


class TestExecTaskFuncWrapper(unittest.TestCase):
class TestTaskScriptRunner(unittest.TestCase):
def test_app_scripts_and_args(self):
curr_dir = os.getcwd()
script_path = "nvflare/cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])

def test_app_scripts_and_args2(self):
curr_dir = os.getcwd()
script_path = "cli.py"
script_args = "--batch_size 4"
wrapper = TaskScriptRunner(script_path=script_path, script_args=script_args)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "cli.py"), "--batch_size", "4"])

def test_app_scripts_with_sub_dirs1(self):
curr_dir = os.getcwd()
script_path = "nvflare/__init__.py"
wrapper = TaskScriptRunner(script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "__init__.py")])

def test_app_scripts_with_sub_dirs2(self):
curr_dir = os.getcwd()
script_path = "nvflare/app_common/executors/__init__.py"
wrapper = TaskScriptRunner(script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(
wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "app_common", "executors", "__init__.py")]
)

def test_app_scripts_with_sub_dirs3(self):
curr_dir = os.getcwd()
script_path = "executors/task_script_runner.py"
wrapper = TaskScriptRunner(script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(
wrapper.get_sys_argv(),
[os.path.join(curr_dir, "nvflare", "app_common", "executors", "task_script_runner.py")],
)

def test_app_scripts_with_sub_dirs4(self):
curr_dir = os.getcwd()
script_path = "in_process/api.py"
wrapper = TaskScriptRunner(script_path=script_path)

self.assertTrue(wrapper.script_path.endswith(script_path))
self.assertEqual(wrapper.get_sys_argv(), [os.path.join(curr_dir, "nvflare", "client", "in_process", "api.py")])

0 comments on commit 3dd4476

Please sign in to comment.