From 91d8a457db76cee1970fbabd6941325c656b758e Mon Sep 17 00:00:00 2001 From: Daniel <63580393+danrgll@users.noreply.github.com> Date: Thu, 7 Dec 2023 00:08:25 +0100 Subject: [PATCH] add type specification for arguments + add more detailed DocStrings for paramter validation functions --- neps/search_spaces/search_space.py | 9 +- neps/search_spaces/yaml_search_space_utils.py | 165 ++++++++++++++---- 2 files changed, 140 insertions(+), 34 deletions(-) diff --git a/neps/search_spaces/search_space.py b/neps/search_spaces/search_space.py index d43fd684..bcf5b18b 100644 --- a/neps/search_spaces/search_space.py +++ b/neps/search_spaces/search_space.py @@ -6,6 +6,7 @@ from collections import OrderedDict from copy import deepcopy from itertools import product +from pathlib import Path import ConfigSpace as CS import numpy as np @@ -66,7 +67,11 @@ def pipeline_space_from_configspace( return pipeline_space -def pipeline_space_from_yaml(yaml_file_path): +def pipeline_space_from_yaml( + yaml_file_path: str | Path, +) -> dict[ + str, FloatParameter | IntegerParameter | CategoricalParameter | ConstantParameter +]: """ Reads configuration details from a YAML file and constructs a pipeline space dictionary. @@ -76,7 +81,7 @@ def pipeline_space_from_yaml(yaml_file_path): maps parameter names to their respective configuration objects. Args: - yaml_file_path (Union[str, Path]): Path to the YAML file containing parameter + yaml_file_path (str | Path): Path to the YAML file containing parameter configurations. Returns: diff --git a/neps/search_spaces/yaml_search_space_utils.py b/neps/search_spaces/yaml_search_space_utils.py index b3789842..fe278dd6 100644 --- a/neps/search_spaces/yaml_search_space_utils.py +++ b/neps/search_spaces/yaml_search_space_utils.py @@ -1,10 +1,36 @@ +from __future__ import annotations + import re -def convert_scientific_notation(value, show_usage_flag=False): - """Check if the value is a string that matches scientific e notation and convert it - to float. (specially numbers like 3.3e-5 with a float value in front, which yaml - can not interpret directly as float).""" +def convert_scientific_notation( + value: str | int | float, show_usage_flag=False +) -> float | (float, bool): + """ + Convert a given value to a float if it's a string that matches scientific e notation. + This is especially useful for numbers like "3.3e-5" which YAML parsers may not + directly interpret as floats. + + If the 'show_usage_flag' is set to True, the function returns a tuple of the float + conversion and a boolean flag indicating whether scientific notation was detected. + + Args: + value (str | int | float): The value to convert. Can be an integer, float, + or a string representing a number, possibly in + scientific notation. + show_usage_flag (bool): Optional; defaults to False. If True, the function + also returns a flag indicating whether scientific + notation was detected in the string. + + Returns: + float: The value converted to float if 'show_usage_flag' is False. + (float, bool): A tuple containing the value converted to float and a flag + indicating scientific notation detection if 'show_usage_flag' + is True. + + Raises: + ValueError: If the value is a string and does not represent a valid number. + """ e_notation_pattern = r"^-?\d+(\.\d+)?[eE]-?\d+$" @@ -54,7 +80,9 @@ def __init__(self, exception): super().__init__(self.message) -def deduce_and_validate_param_type(name, details): +def deduce_and_validate_param_type( + name: str, details: dict[str, str | int | float] +) -> str: """ Deduces the parameter type from details and validates them. @@ -82,16 +110,18 @@ def deduce_and_validate_param_type(name, details): return param_type -def deduce_param_type(name, details): +def deduce_param_type(name: str, details: dict[str, int | str | float]) -> str: """Deduces the parameter type based on the provided details. - This function analyzes the provided details dictionary to determine the type of - parameter. It supports identifying integer, float, categorical, and constant - parameter types. + The function interprets the 'details' dictionary to determine the parameter type. + The dictionary should include key-value pairs that describe the parameter's + characteristics, such as lower, upper, default value, or possible choices. + Args: name (str): The name of the parameter. - details (dict): A dictionary containing parameter specifications. + details ((dict[str, int | str | float])): A dictionary containing parameter + specifications. Returns: str: The deduced parameter type ('int', 'float', 'categorical', or 'constant'). @@ -112,12 +142,19 @@ def deduce_param_type(name, details): elif isinstance(details["lower"], float) and isinstance(details["upper"], float): param_type = "float" else: - details["lower"], flag_lower = convert_scientific_notation( - details["lower"], show_usage_flag=True - ) - details["upper"], flag_upper = convert_scientific_notation( - details["upper"], show_usage_flag=True - ) + try: + details["lower"], flag_lower = convert_scientific_notation( + details["lower"], show_usage_flag=True + ) + details["upper"], flag_upper = convert_scientific_notation( + details["upper"], show_usage_flag=True + ) + except ValueError as e: + raise TypeError( + f"Inconsistent types for 'lower' and 'upper' in '{name}'. " + f"Both must be either integers or floats." + ) from e + # check if one value is e notation and if so convert it to float if flag_lower or flag_upper: param_type = "float" @@ -145,7 +182,9 @@ def deduce_param_type(name, details): return param_type -def validate_param_details(name, param_type, details): +def validate_param_details( + name: str, param_type: str, details: dict[str, int | str | float] +): """ Validates the details of a parameter based on its type. @@ -166,9 +205,6 @@ def validate_param_details(name, param_type, details): the necessary keys in the 'details' are missing based on the parameter type. TypeError: If the 'param_type' is not one of the supported types. - Returns: - str: The parameter type in lowercase. - Example Usage: validate_param_details("learning_rate", "float", {"lower": 0.01, "upper": 0.1, "default": 0.05}) @@ -203,11 +239,28 @@ def validate_param_details(name, param_type, details): "For categorical parameter: cat, categorical\n" "For constant parameter: const, constant\n" ) - return param_type -def validate_integer_parameter(name, details): - """validate int parameter and convert e notation values to int""" +def validate_integer_parameter(name: str, details: dict[str, str | int | float]): + """ + Validates and processes an integer parameter's details, converting scientific + notation to integers where necessary. + + This function checks the type of 'lower' and 'upper', and the 'default' + value (if present) for an integer parameter. It also handles conversion of values + in scientific notation (e.g., 1e2) to integers. + + Args: + name (str): The name of the integer parameter. + details (dict[str, str | int | float]): A dictionary containing the parameter's + specifications. Expected keys include + 'lower', 'upper', and optionally 'default', + among others. + + Raises: + TypeError: If 'lower', 'upper', or 'default' are not valid integers or cannot + be converted from scientific notation to integers. + """ # check if all keys are allowed to use and if the mandatory ones are provided check_keys( name, @@ -254,8 +307,24 @@ def validate_integer_parameter(name, details): ) from e -def validate_float_parameter(name, details): - """validate float parameter and convert e notation values to float""" +def validate_float_parameter(name: str, details: dict[str, str | int | float]): + """ + Validates and processes a float parameter's details, converting scientific + notation values to float where necessary. + + This function checks the type of 'lower' and 'upper', and the 'default' + value (if present) for a float parameter. It handles conversion of values in + scientific notation (e.g., 1e-5) to float. + + Args: + name: The name of the float parameter. + details: A dictionary containing the parameter's specifications. Expected keys + include 'lower', 'upper', and optionally 'default', among others. + + Raises: + TypeError: If 'lower', 'upper', or 'default' are not valid floats or cannot + be converted from scientific notation to floats. + """ # check if all keys are allowed to use and if the mandatory ones are provided check_keys( name, @@ -284,8 +353,23 @@ def validate_float_parameter(name, details): ) from e -def validate_categorical_parameter(name, details): - """validate categorical parameter and convert e notation values to float""" +def validate_categorical_parameter(name: str, details: dict[str, str | int | float]): + """ + Validates a categorical parameter, including conversion of scientific notation + values to floats within the choices. + + This function ensures that the 'choices' key in the details is a list and attempts + to convert any elements in scientific notation to floats. It also handles the + 'default' value, converting it from scientific notation if necessary. + + Args: + name: The name of the categorical parameter. + details: A dictionary containing the parameter's specifications. Required key + is 'choices', with 'default' being optional. + + Raises: + TypeError: If 'choices' is not a list + """ # check if all keys are allowed to use and if the mandatory ones are provided check_keys( name, @@ -320,8 +404,20 @@ def validate_categorical_parameter(name, details): details["default"] = default -def validate_constant_parameter(name, details): - """Validate constant parameter and convert e notation to float""" +def validate_constant_parameter(name: str, details: dict[str, str | int | float]): + """ + Validates a constant parameter, including conversion of values in scientific + notation to floats. + + This function checks the 'value' key in the details dictionary and converts any + value expressed in scientific notation to a float. It ensures that the mandatory + 'value' key is provided and appropriately formatted. + + Args: + name: The name of the constant parameter. + details: A dictionary containing the parameter's specifications. The required + key is 'value'. + """ # check if all keys are allowed to use and if the mandatory ones are provided check_keys(name, details, {"value", "type", "is_fidelity"}, {"value"}) @@ -339,14 +435,19 @@ def validate_constant_parameter(name, details): details["value"] = converted_value -def check_keys(name, my_dict, allowed_keys, mandatory_keys): +def check_keys( + name: str, + details: dict[str, str | int | float], + allowed_keys: set, + mandatory_keys: set, +): """ Checks if all keys in 'my_dict' are contained in the set 'allowed_keys' and if all keys in 'mandatory_keys' are present in 'my_dict'. Raises an exception if an unallowed key is found or if a mandatory key is missing. """ # Check for unallowed keys - unallowed_keys = [key for key in my_dict if key not in allowed_keys] + unallowed_keys = [key for key in details if key not in allowed_keys] if unallowed_keys: unallowed_keys_str = ", ".join(unallowed_keys) raise KeyError( @@ -354,7 +455,7 @@ def check_keys(name, my_dict, allowed_keys, mandatory_keys): ) # Check for missing mandatory keys - missing_mandatory_keys = [key for key in mandatory_keys if key not in my_dict] + missing_mandatory_keys = [key for key in mandatory_keys if key not in details] if missing_mandatory_keys: missing_keys_str = ", ".join(missing_mandatory_keys) raise KeyError(