Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bundle root directory to Python search directories automatically #6910

Merged
merged 19 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import os
import sys
import time
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -224,23 +225,23 @@ def __init__(
super().__init__(workflow_type=workflow_type)
if config_file is not None:
_config_files = ensure_tuple(config_file)
config_root_path = Path(_config_files[0]).parent
self.config_root_path = Path(_config_files[0]).parent
for _config_file in _config_files:
_config_file = Path(_config_file)
if _config_file.parent != config_root_path:
if _config_file.parent != self.config_root_path:
warnings.warn(
f"Not all config files are in {config_root_path}. If logging_file and meta_file are"
f"not specified, {config_root_path} will be used as the default config root directory."
f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
f"not specified, {self.config_root_path} will be used as the default config root directory."
)
if not _config_file.is_file():
raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
else:
config_root_path = Path("configs")
self.config_root_path = Path("configs")

logging_file = str(config_root_path / "logging.conf") if logging_file is None else logging_file
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
if logging_file is not None:
if not os.path.exists(logging_file):
if logging_file == str(config_root_path / "logging.conf"):
if logging_file == str(self.config_root_path / "logging.conf"):
warnings.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
Expand All @@ -250,7 +251,7 @@ def __init__(

self.parser = ConfigParser()
self.parser.read_config(f=config_file)
meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
if isinstance(meta_file, str) and not os.path.exists(meta_file):
raise FileNotFoundError(f"Cannot find the metadata config file: {meta_file}.")
else:
Expand Down Expand Up @@ -285,6 +286,10 @@ def run(self) -> Any:
Run the bundle workflow, it can be a training, evaluation or inference.

"""
_bundle_root_path = (
self.config_root_path.parent if self.config_root_path.name == "configs" else self.config_root_path
)
sys.path.append(str(_bundle_root_path))
wyli marked this conversation as resolved.
Show resolved Hide resolved
if self.run_id not in self.parser:
raise ValueError(f"run ID '{self.run_id}' doesn't exist in the config file.")
return self._run_expr(id=self.run_id)
Expand Down
72 changes: 72 additions & 0 deletions tests/test_integration_bundle_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
Expand Down Expand Up @@ -44,6 +45,14 @@ def run(self):
return self.val


class _Runnable43:
def __init__(self, func):
self.func = func

def run(self):
self.func()


class TestBundleRun(unittest.TestCase):
def setUp(self):
self.data_dir = tempfile.mkdtemp()
Expand Down Expand Up @@ -77,6 +86,69 @@ def test_tiny(self):
with self.assertRaises(RuntimeError):
# test wrong run_id="run"
command_line_tests(cmd + ["run", "run", "--config_file", config_file])
with self.assertRaises(RuntimeError):
# test missing meta file
command_line_tests(cmd + ["run", "training", "--config_file", config_file])

def test_scripts_fold(self):
# test scripts directory has been added to Python search directories automatically
config_file = os.path.join(self.data_dir, "tiny_config.json")
meta_file = os.path.join(self.data_dir, "tiny_meta.json")
scripts_dir = os.path.join(self.data_dir, "scripts")
script_file = os.path.join(scripts_dir, "test_scripts_fold.py")
init_file = os.path.join(scripts_dir, "__init__.py")

with open(config_file, "w") as f:
json.dump(
{
"imports": ["$import scripts"],
"trainer": {
"_target_": "tests.test_integration_bundle_run._Runnable43",
"func": "$scripts.tiny_test",
},
# keep this test case to cover the "runner_id" arg
"training": "$@trainer.run()",
},
f,
)
with open(meta_file, "w") as f:
json.dump(
{"version": "0.1.0", "monai_version": "1.1.0", "pytorch_version": "1.13.1", "numpy_version": "1.22.2"},
f,
)

os.mkdir(scripts_dir)
script_file_lines = ["def tiny_test():\n", " print('successfully added scripts fold!') \n"]
init_file_line = "from .test_scripts_fold import tiny_test\n"
with open(script_file, "w") as f:
f.writelines(script_file_lines)
f.close()
with open(init_file, "w") as f:
f.write(init_file_line)
f.close()

cmd = ["coverage", "run", "-m", "monai.bundle"]
# test both CLI entry "run" and "run_workflow"
expected_condition = "successfully added scripts fold!"
command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]
completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
print(output)

self.assertTrue(expected_condition in output)
command_run_workflow = cmd + [
"run_workflow",
"--run_id",
"training",
"--config_file",
config_file,
"--meta_file",
meta_file,
]
completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
print(output)
self.assertTrue(expected_condition in output)

with self.assertRaises(RuntimeError):
# test missing meta file
Expand Down