diff --git a/neps/api.py b/neps/api.py index 5cf870cd..58f3df01 100644 --- a/neps/api.py +++ b/neps/api.py @@ -318,7 +318,7 @@ def _run_args( if isinstance(pipeline_space, CS.ConfigurationSpace): pipeline_space = pipeline_space_from_configspace(pipeline_space) # Support pipeline space as YAML file - elif isinstance(pipeline_space, str): + elif isinstance(pipeline_space, (str, Path)): pipeline_space = pipeline_space_from_yaml(pipeline_space) # Support pipeline space as mix of ConfigurationSpace and neps parameters diff --git a/neps/search_spaces/search_space.py b/neps/search_spaces/search_space.py index 92856910..d43fd684 100644 --- a/neps/search_spaces/search_space.py +++ b/neps/search_spaces/search_space.py @@ -76,7 +76,8 @@ def pipeline_space_from_yaml(yaml_file_path): maps parameter names to their respective configuration objects. Args: - yaml_file_path (str): Path to the YAML file containing parameter configurations. + yaml_file_path (Union[str, Path]): Path to the YAML file containing parameter + configurations. Returns: dict: A dictionary where keys are parameter names and values are parameter @@ -109,7 +110,7 @@ def pipeline_space_from_yaml(yaml_file_path): config = yaml.safe_load(file) except yaml.YAMLError as e: raise ValueError( - f"The file at {yaml_file_path} is not a valid YAML file." + f"The file at {str(yaml_file_path)} is not a valid YAML file." ) from e # check for init key search_space diff --git a/tests/test_yaml_search_space/test_search_space.py b/tests/test_yaml_search_space/test_search_space.py index 5f01118d..f354b128 100644 --- a/tests/test_yaml_search_space/test_search_space.py +++ b/tests/test_yaml_search_space/test_search_space.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from neps.search_spaces.search_space import ( SearchSpaceFromYamlFileError, @@ -94,7 +96,9 @@ def test_correct_including_priors_yaml_file(): def test_incorrect_yaml_file(): """Test the function with an incorrectly formatted YAML file.""" with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: - pipeline_space_from_yaml("tests/test_yaml_search_space/incorrect_config.txt") + pipeline_space_from_yaml( + Path("tests/test_yaml_search_space/incorrect_config.txt") + ) assert str(excinfo.value.exception_type == "ValueError") @@ -117,7 +121,7 @@ def test_yaml_file_with_inconsistent_types(): assert str(excinfo.value.exception_type == "TypeError") with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: pipeline_space_from_yaml( - "tests/test_yaml_search_space/inconsistent_types_config2.yml" + Path("tests/test_yaml_search_space/inconsistent_types_config2.yml") ) assert str(excinfo.value.exception_type == "TypeError")