Skip to content

Commit

Permalink
adapt tests and documentation for yaml_search_space
Browse files Browse the repository at this point in the history
  • Loading branch information
danrgll committed Nov 25, 2023
1 parent 155ed07 commit a62c1cc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
28 changes: 20 additions & 8 deletions src/neps/search_spaces/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def pipeline_space_from_yaml(yaml_file_path):
Raises:
KeyError: If any mandatory configuration for a parameter is missing in the YAML file.
ValueError: If lower and upper are not the same type of value
TypeError: If lower and upper are not the same type of value
ValueError: if choices is not a list
ValueError: If an unknown parameter type is encountered.
ValueError: If YAML file is incorrectly constructed
KeyError: If an unknown parameter type is encountered.
KeyError: If YAML file is incorrectly constructed
"""
# Load the YAML file
try:
Expand All @@ -86,16 +86,24 @@ def pipeline_space_from_yaml(yaml_file_path):
except yaml.YAMLError as e:
raise ValueError(f"The file at {yaml_file_path} is not a valid YAML file.") from e

# check for key config_space
# check for key search_space
if "search_space" not in config:
raise ValueError(
"The YAML file is incorrectly constructed: 'config_space' key is missing."
raise KeyError(
"The YAML file is incorrectly constructed: the 'search_space:' "
"reference is missing at the top of the file."
)

# Initialize the pipeline space
pipeline_space = {}
# Iterate over the items in the YAML configuration
for name, details in config["search_space"].items():
if not (isinstance(name, str) and isinstance(details, dict)):
raise KeyError(
f"Invalid format for {name} in YAML file. "
f"Expected 'name' as string and corresponding 'details' as a dictionary. "
f"Found 'name' type: {type(name).__name__}, 'details' type:"
f" {type(details).__name__}."
)
if "lower" in details and "upper" in details:
# Determine if it's an integer or float range parameter
if isinstance(details["lower"], int) and isinstance(details["upper"], int):
Expand All @@ -105,7 +113,7 @@ def pipeline_space_from_yaml(yaml_file_path):
):
param_type = FloatParameter
else:
raise ValueError(
raise TypeError(
f"Inconsistent types for 'lower' and 'upper' in '{name}'. "
f"Both must be either integers or floats."
)
Expand Down Expand Up @@ -136,8 +144,12 @@ def pipeline_space_from_yaml(yaml_file_path):
else:
# Handle unknown parameter types
raise KeyError(
f"Unsupported parameter format for '{name}'. "
f"Unsupported parameter format for '{name}'."
f"Expected keys not found in {details}."
"Supported parameters:"
"Float and Integer: Expected keys: 'lower', 'upper'"
"Categorical: Expected keys: 'choices'"
"Constant: Expected keys: 'value'"
)

return pipeline_space
Expand Down
10 changes: 9 additions & 1 deletion tests/test_yaml_search_space/test_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ def test_correct_yaml_file():
assert pipeline_space["learning_rate"].default is None
assert pipeline_space["learning_rate"].default_confidence_score == 0.5
assert isinstance(pipeline_space["num_epochs"], IntegerParameter)
assert pipeline_space["num_epochs"].lower == 3
assert pipeline_space["num_epochs"].upper == 30
assert pipeline_space["num_epochs"].log is False
assert pipeline_space["num_epochs"].is_fidelity is True
assert pipeline_space["num_epochs"].default is None
assert pipeline_space["num_epochs"].default_confidence_score == 0.5
assert isinstance(pipeline_space["optimizer"], CategoricalParameter)
assert pipeline_space["optimizer"].choices == ["adam", "sgd", "rmsprop"]
assert pipeline_space["optimizer"].is_fidelity is False
assert pipeline_space["optimizer"].default is None
assert pipeline_space["optimizer"].default_confidence_score == 2
Expand All @@ -47,11 +50,14 @@ def test_correct_including_priors_yaml_file():
assert pipeline_space["learning_rate"].default == 0.001
assert pipeline_space["learning_rate"].default_confidence_score == 0.125
assert isinstance(pipeline_space["num_epochs"], IntegerParameter)
assert pipeline_space["num_epochs"].lower == 3
assert pipeline_space["num_epochs"].upper == 30
assert pipeline_space["num_epochs"].log is False
assert pipeline_space["num_epochs"].is_fidelity is True
assert pipeline_space["num_epochs"].default == 10
assert pipeline_space["num_epochs"].default_confidence_score == 0.25
assert isinstance(pipeline_space["optimizer"], CategoricalParameter)
assert pipeline_space["optimizer"].choices == ["adam", "sgd", "rmsprop"]
assert pipeline_space["optimizer"].is_fidelity is False
assert pipeline_space["optimizer"].default == "sgd"
assert pipeline_space["optimizer"].default_confidence_score == 4
Expand All @@ -78,7 +84,9 @@ def test_yaml_file_with_missing_key():
def test_yaml_file_with_inconsistent_types():
"""Test the function with a YAML file having inconsistent types for
'lower' and 'upper'."""
with pytest.raises(ValueError):
with pytest.raises(TypeError):
pipeline_space_from_yaml(
"tests/test_yaml_search_space/inconsistent_types_config.yml"
)


0 comments on commit a62c1cc

Please sign in to comment.