Skip to content

Commit

Permalink
enable usage of Path object for yaml_file config_space
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Dec 6, 2023
1 parent 15d4cc8 commit c70b85e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion neps/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _run_args(
if isinstance(pipeline_space, CS.ConfigurationSpace):
pipeline_space = pipeline_space_from_configspace(pipeline_space)
# Support pipeline space as YAML file
elif isinstance(pipeline_space, str):
elif isinstance(pipeline_space, (str, Path)):
pipeline_space = pipeline_space_from_yaml(pipeline_space)

# Support pipeline space as mix of ConfigurationSpace and neps parameters
Expand Down
5 changes: 3 additions & 2 deletions neps/search_spaces/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def pipeline_space_from_yaml(yaml_file_path):
maps parameter names to their respective configuration objects.
Args:
yaml_file_path (str): Path to the YAML file containing parameter configurations.
yaml_file_path (Union[str, Path]): Path to the YAML file containing parameter
configurations.
Returns:
dict: A dictionary where keys are parameter names and values are parameter
Expand Down Expand Up @@ -109,7 +110,7 @@ def pipeline_space_from_yaml(yaml_file_path):
config = yaml.safe_load(file)
except yaml.YAMLError as e:
raise ValueError(
f"The file at {yaml_file_path} is not a valid YAML file."
f"The file at {str(yaml_file_path)} is not a valid YAML file."
) from e

# check for init key search_space
Expand Down
8 changes: 6 additions & 2 deletions tests/test_yaml_search_space/test_search_space.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest
from neps.search_spaces.search_space import (
SearchSpaceFromYamlFileError,
Expand Down Expand Up @@ -94,7 +96,9 @@ def test_correct_including_priors_yaml_file():
def test_incorrect_yaml_file():
"""Test the function with an incorrectly formatted YAML file."""
with pytest.raises(SearchSpaceFromYamlFileError) as excinfo:
pipeline_space_from_yaml("tests/test_yaml_search_space/incorrect_config.txt")
pipeline_space_from_yaml(
Path("tests/test_yaml_search_space/incorrect_config.txt")
)
assert str(excinfo.value.exception_type == "ValueError")


Expand All @@ -117,7 +121,7 @@ def test_yaml_file_with_inconsistent_types():
assert str(excinfo.value.exception_type == "TypeError")
with pytest.raises(SearchSpaceFromYamlFileError) as excinfo:
pipeline_space_from_yaml(
"tests/test_yaml_search_space/inconsistent_types_config2.yml"
Path("tests/test_yaml_search_space/inconsistent_types_config2.yml")
)
assert str(excinfo.value.exception_type == "TypeError")

Expand Down

0 comments on commit c70b85e

Please sign in to comment.