From 2723c09d63d41371622ae6827c33111bf6766571 Mon Sep 17 00:00:00 2001 From: Daniel <63580393+danrgll@users.noreply.github.com> Date: Wed, 6 Dec 2023 15:04:27 +0100 Subject: [PATCH] made code more readable for validate parameter inputs + add tests + comment functions --- src/neps/search_spaces/search_space.py | 39 +- .../search_spaces/yaml_search_space_utils.py | 344 ++++++++++-------- .../inconsistent_types_config.yml | 4 +- .../inconsistent_types_config2.yml | 18 + .../test_search_space.py | 5 + 5 files changed, 239 insertions(+), 171 deletions(-) create mode 100644 tests/test_yaml_search_space/inconsistent_types_config2.yml diff --git a/src/neps/search_spaces/search_space.py b/src/neps/search_spaces/search_space.py index 534ca041..92856910 100644 --- a/src/neps/search_spaces/search_space.py +++ b/src/neps/search_spaces/search_space.py @@ -68,35 +68,38 @@ def pipeline_space_from_configspace( def pipeline_space_from_yaml(yaml_file_path): """ - Reads configuration details from a YAML file and creates a dictionary of parameters. + Reads configuration details from a YAML file and constructs a pipeline space + dictionary. - This function parses a YAML file to extract configuration details and organizes them - into a dictionary. Each key in the dictionary corresponds to a parameter name, and - the value is an object representing the parameter configuration. + This function extracts parameter configurations from a YAML file, validating and + translating them into corresponding parameter objects. The resulting dictionary + maps parameter names to their respective configuration objects. Args: - yaml_file_path (str): Path to the YAML file containing configuration details. + yaml_file_path (str): Path to the YAML file containing parameter configurations. Returns: - dict: A dictionary with parameter names as keys and parameter objects as values. + dict: A dictionary where keys are parameter names and values are parameter + objects (like IntegerParameter, FloatParameter, etc.). Raises: - SearchSpaceFromYamlFileError: Wraps and re-raises exceptions (KeyError, TypeError, - ValueError) that occur during the initialization of the search space from the YAML - file. This custom exception class provides additional context about the error, - enhancing diagnostic clarity and simplifying error handling for function callers. - It includes the type of the original exception and a descriptive message, thereby - localizing error handling to this specific function and preventing the propagation - of these generic exceptions. + SearchSpaceFromYamlFileError: This custom exception is raised if there are issues + with the YAML file's format or contents. It encapsulates underlying exceptions + (KeyError, TypeError, ValueError) that occur during the processing of the YAML + file. This approach localizes error handling, providing clearer context and + diagnostics. The raised exception includes the type of the original error and + a descriptive message. Note: - The YAML file must be structured correctly with appropriate keys and values for - each parameter type. The function validates the structure and content of the YAML - file, raising specific errors for missing mandatory configuration details, type - mismatches, and unknown parameter types. + The YAML file should be properly structured with valid keys and values as per the + expected parameter types. The function employs modular validation and type + deduction logic, ensuring each parameter's configuration adheres to expected + formats and constraints. Any deviation results in an appropriately raised error, + which is then captured by SearchSpaceFromYamlFileError for streamlined error + handling. Example: - Given a YAML file 'config.yaml', call the function as follows: + To use this function with a YAML file 'config.yaml', you can do: pipeline_space = pipeline_space_from_yaml('config.yaml') """ try: diff --git a/src/neps/search_spaces/yaml_search_space_utils.py b/src/neps/search_spaces/yaml_search_space_utils.py index afb040f0..b3789842 100644 --- a/src/neps/search_spaces/yaml_search_space_utils.py +++ b/src/neps/search_spaces/yaml_search_space_utils.py @@ -8,7 +8,7 @@ def convert_scientific_notation(value, show_usage_flag=False): e_notation_pattern = r"^-?\d+(\.\d+)?[eE]-?\d+$" - flag = False # Check if e notation was detected + flag = False # Flag if e notation was detected if isinstance(value, str): # Remove all whitespace from the string @@ -76,7 +76,7 @@ def deduce_and_validate_param_type(name, details): # Logic to infer type if not explicitly provided param_type = deduce_param_type(name, details) - # Validate details based on deduced type + # Validate details of a parameter based on (deduced) type validate_param_details(name, param_type, details) return param_type @@ -103,6 +103,8 @@ def deduce_param_type(name, details): Example: param_type = deduce_param_type('example_param', {'lower': 0, 'upper': 10})""" # Logic to deduce type from details + + # check for int and float conditions if "lower" in details and "upper" in details: # Determine if it's an integer or float range parameter if isinstance(details["lower"], int) and isinstance(details["upper"], int): @@ -116,7 +118,7 @@ def deduce_param_type(name, details): details["upper"], flag_upper = convert_scientific_notation( details["upper"], show_usage_flag=True ) - # check if one value is 10^format to convert it to float + # check if one value is e notation and if so convert it to float if flag_lower or flag_upper: param_type = "float" else: @@ -124,9 +126,11 @@ def deduce_param_type(name, details): f"Inconsistent types for 'lower' and 'upper' in '{name}'. " f"Both must be either integers or floats." ) - + # check for categorical condition elif "choices" in details: param_type = "categorical" + + # check for constant condition elif "value" in details: param_type = "constant" else: @@ -142,6 +146,33 @@ def deduce_param_type(name, details): def validate_param_details(name, param_type, details): + """ + Validates the details of a parameter based on its type. + + This function checks the format and type-specific details of a parameter + specified in a YAML file. It ensures that the 'name' of the parameter is a string + and its 'details' are provided as a dictionary. Depending on the parameter type, + it delegates the validation to the appropriate type-specific validation function. + + Parameters: + name (str): The name of the parameter. It should be a string. + param_type (str): The type of the parameter. Supported types are 'int' (or 'integer'), + 'float', 'cat' (or 'categorical'), and 'const' (or 'constant'). + details (dict): The detailed configuration of the parameter, which includes its + attributes like 'lower', 'upper', 'default', etc. + + Raises: + KeyError: If the 'name' is not a string or 'details' is not a dictionary, or if + the necessary keys in the 'details' are missing based on the parameter type. + TypeError: If the 'param_type' is not one of the supported types. + + Returns: + str: The parameter type in lowercase. + + Example Usage: + validate_param_details("learning_rate", "float", {"lower": 0.01, "upper": 0.1, + "default": 0.05}) + """ if not (isinstance(name, str) and isinstance(details, dict)): raise KeyError( f"Invalid format for {name} in YAML file. " @@ -152,155 +183,20 @@ def validate_param_details(name, param_type, details): param_type = param_type.lower() # init parameter by checking type if param_type in ("int", "integer"): - # check if all keys are allowed - check_allowed_keys( - name, - details, - { - "lower", - "upper", - "type", - "log", - "is_fidelity", - "default", - "default_confidence", - }, - ) - # Check Integer Parameter - if "lower" not in details or "upper" not in details: - raise KeyError( - f"Missing 'lower' or 'upper' for integer " f"parameter '{name}'." - ) - if not isinstance(details["lower"], int) or not isinstance(details["upper"], int): - try: - # for numbers like 1e2 and 10^ - lower, flag_lower = convert_scientific_notation( - details["lower"], show_usage_flag=True - ) - upper, flag_upper = convert_scientific_notation( - details["upper"], show_usage_flag=True - ) - # check if one value format is e or 10^ and if its an integer - if flag_lower or flag_upper: - if lower == int(lower) and upper == int(upper): - details["lower"] = int(lower) - details["upper"] = int(upper) - else: - raise ValueError() - else: - raise ValueError() - except ValueError as e: - raise TypeError( - f"'lower' and 'upper' must be integer for " - f"integer parameter '{name}'." - ) from e - if "default" in details: - if not isinstance(details["default"], int): - default = convert_scientific_notation(details["default"]) - if default == int(default): - details["default"] = int(default) - else: - raise TypeError( - f"default value {details['default']} " - f"must be integer for integer parameter {name}" - ) + validate_integer_parameter(name, details) elif param_type == "float": - # check if all keys are allowed - check_allowed_keys( - name, - details, - { - "lower", - "upper", - "type", - "log", - "is_fidelity", - "default", - "default_confidence", - }, - ) - # Check Float Parameter - if "lower" not in details or "upper" not in details: - raise KeyError( - f"Missing key 'lower' or 'upper' for float " f"parameter '{name}'." - ) - if not isinstance(details["lower"], float) or not isinstance( - details["upper"], float - ): - try: - # for numbers like 1e-5 and 10^ - details["lower"] = convert_scientific_notation(details["lower"]) - details["upper"] = convert_scientific_notation(details["upper"]) - except ValueError as e: - raise TypeError( - f"'lower' and 'upper' must be integer for " - f"integer parameter '{name}'." - ) from e - if "default" in details: - if not isinstance(details["default"], float): - try: - details["default"] = convert_scientific_notation(details["default"]) - except ValueError as e: - raise TypeError( - f" 'default' must be float for float parameter " f"{name} " - ) from e + validate_float_parameter(name, details) elif param_type in ("cat", "categorical"): - # check if all keys are allowed - check_allowed_keys( - name, - details, - {"choices", "type", "is_fidelity", "default", "default_confidence"}, - ) - # Check Categorical parameter - if "choices" not in details: - raise KeyError(f"Missing key 'choices' for categorical " f"parameter {name}") - if not isinstance(details["choices"], (list, tuple)): - raise TypeError(f"The 'choices' for '{name}' must be a list or tuple.") - for i, element in enumerate(details["choices"]): - try: - converted_value, e_flag = convert_scientific_notation( - element, show_usage_flag=True - ) - if e_flag: - details["choices"][ - i - ] = converted_value # Replace the element at the same position - except ValueError: - pass # If a ValueError occurs, simply continue to the next element - if "default" in details: - e_flag = False - try: - # check if e notation, if then convert to number - default, e_flag = convert_scientific_notation( - details["default"], show_usage_flag=True - ) - except ValueError: - pass - if e_flag is True: - details["default"] = default - elif param_type in ("const", "constant"): - # check if all keys are allowed - check_allowed_keys(name, details, {"value", "type", "is_fidelity"}) - # Check Constant parameter - if "value" not in details: - raise KeyError(f"Missing key 'value' for constant parameter " f"{name}") - else: - e_flag = False - try: - converted_value, e_flag = convert_scientific_notation( - details["value"], show_usage_flag=True - ) - except ValueError: - pass - if e_flag: - details["value"] = converted_value + validate_categorical_parameter(name, details) + elif param_type in ("const", "constant"): + validate_constant_parameter(name, details) else: # Handle unknown parameter types raise TypeError( - f"Unsupported parameter type{details['type']} for '{name}'.\n" + f"Unsupported parameter type'{details['type']}' for '{name}'.\n" f"Supported Types for argument type are:\n" "For integer parameter: int, integer\n" "For float parameter: float\n" @@ -310,11 +206,157 @@ def validate_param_details(name, param_type, details): return param_type -def check_allowed_keys(name, my_dict, allowed_keys): +def validate_integer_parameter(name, details): + """validate int parameter and convert e notation values to int""" + # check if all keys are allowed to use and if the mandatory ones are provided + check_keys( + name, + details, + {"lower", "upper", "type", "log", "is_fidelity", "default", "default_confidence"}, + {"lower", "upper"}, + ) + + if not isinstance(details["lower"], int) or not isinstance(details["upper"], int): + try: + # for numbers like 1e2 and 10^ + lower, flag_lower = convert_scientific_notation( + details["lower"], show_usage_flag=True + ) + upper, flag_upper = convert_scientific_notation( + details["upper"], show_usage_flag=True + ) + # check if one value format is e notation and if its an integer + if flag_lower or flag_upper: + if lower == int(lower) and upper == int(upper): + details["lower"] = int(lower) + details["upper"] = int(upper) + else: + raise TypeError() + else: + raise TypeError() + except (ValueError, TypeError) as e: + raise TypeError( + f"'lower' and 'upper' must be integer for " f"integer parameter '{name}'." + ) from e + if "default" in details: + if not isinstance(details["default"], int): + try: + # convert value can raise ValueError + default = convert_scientific_notation(details["default"]) + if default == int(default): + details["default"] = int(default) + else: + raise TypeError() # type of value is not int + except (ValueError, TypeError) as e: + raise TypeError( + f"default value {details['default']} " + f"must be integer for integer parameter {name}" + ) from e + + +def validate_float_parameter(name, details): + """validate float parameter and convert e notation values to float""" + # check if all keys are allowed to use and if the mandatory ones are provided + check_keys( + name, + details, + {"lower", "upper", "type", "log", "is_fidelity", "default", "default_confidence"}, + {"lower", "upper"}, + ) + + if not isinstance(details["lower"], float) or not isinstance(details["upper"], float): + try: + # for numbers like 1e-5 and 10^ + details["lower"] = convert_scientific_notation(details["lower"]) + details["upper"] = convert_scientific_notation(details["upper"]) + except ValueError as e: + raise TypeError( + f"'lower' and 'upper' must be integer for " f"integer parameter '{name}'." + ) from e + if "default" in details: + if not isinstance(details["default"], float): + try: + details["default"] = convert_scientific_notation(details["default"]) + except ValueError as e: + raise TypeError( + f" default'{details['default']}' must be float for float " + f"parameter {name} " + ) from e + + +def validate_categorical_parameter(name, details): + """validate categorical parameter and convert e notation values to float""" + # check if all keys are allowed to use and if the mandatory ones are provided + check_keys( + name, + details, + {"choices", "type", "is_fidelity", "default", "default_confidence"}, + {"choices"}, + ) + + if not isinstance(details["choices"], list): + raise TypeError(f"The 'choices' for '{name}' must be a list.") + for i, element in enumerate(details["choices"]): + try: + converted_value, e_flag = convert_scientific_notation( + element, show_usage_flag=True + ) + if e_flag: + details["choices"][ + i + ] = converted_value # Replace the element at the same position + except ValueError: + pass # If a ValueError occurs, simply continue to the next element + if "default" in details: + e_flag = False + try: + # check if e notation, if then convert to number + default, e_flag = convert_scientific_notation( + details["default"], show_usage_flag=True + ) + except ValueError: + pass # if default value is not in a numeric format, Value Error occurs + if e_flag is True: + details["default"] = default + + +def validate_constant_parameter(name, details): + """Validate constant parameter and convert e notation to float""" + # check if all keys are allowed to use and if the mandatory ones are provided + check_keys(name, details, {"value", "type", "is_fidelity"}, {"value"}) + + # check for e notation and convert it to float + e_flag = False + try: + converted_value, e_flag = convert_scientific_notation( + details["value"], show_usage_flag=True + ) + except ValueError: + # if the value is not able to convert to float a ValueError get raised by + # convert_scientific_notation function + pass + if e_flag: + details["value"] = converted_value + + +def check_keys(name, my_dict, allowed_keys, mandatory_keys): """ - Checks if all keys in 'my_dict' are contained in the set 'allowed_keys'. - If an unallowed key is found, an exception is raised. + Checks if all keys in 'my_dict' are contained in the set 'allowed_keys' and + if all keys in 'mandatory_keys' are present in 'my_dict'. + Raises an exception if an unallowed key is found or if a mandatory key is missing. """ - for key in my_dict: - if key not in allowed_keys: - raise KeyError(f"This key is not allowed: '{key}' for parameter '{name}'") + # Check for unallowed keys + unallowed_keys = [key for key in my_dict if key not in allowed_keys] + if unallowed_keys: + unallowed_keys_str = ", ".join(unallowed_keys) + raise KeyError( + f"Unallowed key(s) '{unallowed_keys_str}' found for parameter '" f"{name}'." + ) + + # Check for missing mandatory keys + missing_mandatory_keys = [key for key in mandatory_keys if key not in my_dict] + if missing_mandatory_keys: + missing_keys_str = ", ".join(missing_mandatory_keys) + raise KeyError( + f"Missing mandatory key(s) '{missing_keys_str}' for parameter '" f"{name}'." + ) diff --git a/tests/test_yaml_search_space/inconsistent_types_config.yml b/tests/test_yaml_search_space/inconsistent_types_config.yml index 3d5eb559..5c3182a2 100644 --- a/tests/test_yaml_search_space/inconsistent_types_config.yml +++ b/tests/test_yaml_search_space/inconsistent_types_config.yml @@ -1,7 +1,7 @@ search_space: learning_rate: - lower: "0.00001" # Lower is now a string - upper: 0.1 + lower: "string" # Lower is now a string + upper: 1e3 log: true num_epochs: diff --git a/tests/test_yaml_search_space/inconsistent_types_config2.yml b/tests/test_yaml_search_space/inconsistent_types_config2.yml new file mode 100644 index 00000000..5f205e92 --- /dev/null +++ b/tests/test_yaml_search_space/inconsistent_types_config2.yml @@ -0,0 +1,18 @@ +search_space: + learning_rate: + type: int + lower: 2.3 # float + upper: 1e3 + log: true + + num_epochs: + lower: 3 + upper: 30 + is_fidelity: True + + optimizer: + choices: ["adam", "sgd", "rmsprop"] + + dropout_rate: + value: 0.5 + is_fidelity: True diff --git a/tests/test_yaml_search_space/test_search_space.py b/tests/test_yaml_search_space/test_search_space.py index aecf617a..4624bd78 100644 --- a/tests/test_yaml_search_space/test_search_space.py +++ b/tests/test_yaml_search_space/test_search_space.py @@ -112,6 +112,11 @@ def test_yaml_file_with_inconsistent_types(): "tests/test_yaml_search_space/inconsistent_types_config.yml" ) assert str(excinfo.value.exception_type == "TypeError") + with pytest.raises(SearchSpaceFromYamlFileError) as excinfo: + pipeline_space_from_yaml( + "tests/test_yaml_search_space/inconsistent_types_config2.yml" + ) + assert str(excinfo.value.exception_type == "TypeError") @pytest.mark.neps_api