Skip to content

Commit

Permalink
add type specification for arguments + add more detailed DocStrings f…
Browse files Browse the repository at this point in the history
…or parameter validation functions
  • Loading branch information
danrgll committed Dec 6, 2023
1 parent c70b85e commit 91d8a45
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 34 deletions.
9 changes: 7 additions & 2 deletions neps/search_spaces/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import OrderedDict
from copy import deepcopy
from itertools import product
from pathlib import Path

import ConfigSpace as CS
import numpy as np
Expand Down Expand Up @@ -66,7 +67,11 @@ def pipeline_space_from_configspace(
return pipeline_space


def pipeline_space_from_yaml(yaml_file_path):
def pipeline_space_from_yaml(
yaml_file_path: str | Path,
) -> dict[
str, FloatParameter | IntegerParameter | CategoricalParameter | ConstantParameter
]:
"""
Reads configuration details from a YAML file and constructs a pipeline space
dictionary.
Expand All @@ -76,7 +81,7 @@ def pipeline_space_from_yaml(yaml_file_path):
maps parameter names to their respective configuration objects.
Args:
yaml_file_path (Union[str, Path]): Path to the YAML file containing parameter
yaml_file_path (str | Path): Path to the YAML file containing parameter
configurations.
Returns:
Expand Down
165 changes: 133 additions & 32 deletions neps/search_spaces/yaml_search_space_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,36 @@
from __future__ import annotations

import re


def convert_scientific_notation(value, show_usage_flag=False):
"""Check if the value is a string that matches scientific e notation and convert it
to float. (specially numbers like 3.3e-5 with a float value in front, which yaml
can not interpret directly as float)."""
def convert_scientific_notation(
value: str | int | float, show_usage_flag=False
) -> float | (float, bool):
"""
Convert a given value to a float if it's a string that matches scientific e notation.
This is especially useful for numbers like "3.3e-5" which YAML parsers may not
directly interpret as floats.
If the 'show_usage_flag' is set to True, the function returns a tuple of the float
conversion and a boolean flag indicating whether scientific notation was detected.
Args:
value (str | int | float): The value to convert. Can be an integer, float,
or a string representing a number, possibly in
scientific notation.
show_usage_flag (bool): Optional; defaults to False. If True, the function
also returns a flag indicating whether scientific
notation was detected in the string.
Returns:
float: The value converted to float if 'show_usage_flag' is False.
(float, bool): A tuple containing the value converted to float and a flag
indicating scientific notation detection if 'show_usage_flag'
is True.
Raises:
ValueError: If the value is a string and does not represent a valid number.
"""

e_notation_pattern = r"^-?\d+(\.\d+)?[eE]-?\d+$"

Expand Down Expand Up @@ -54,7 +80,9 @@ def __init__(self, exception):
super().__init__(self.message)


def deduce_and_validate_param_type(name, details):
def deduce_and_validate_param_type(
name: str, details: dict[str, str | int | float]
) -> str:
"""
Deduces the parameter type from details and validates them.
Expand Down Expand Up @@ -82,16 +110,18 @@ def deduce_and_validate_param_type(name, details):
return param_type


def deduce_param_type(name, details):
def deduce_param_type(name: str, details: dict[str, int | str | float]) -> str:
"""Deduces the parameter type based on the provided details.
This function analyzes the provided details dictionary to determine the type of
parameter. It supports identifying integer, float, categorical, and constant
parameter types.
The function interprets the 'details' dictionary to determine the parameter type.
The dictionary should include key-value pairs that describe the parameter's
characteristics, such as lower, upper, default value, or possible choices.
Args:
name (str): The name of the parameter.
details (dict): A dictionary containing parameter specifications.
details (dict[str, int | str | float]): A dictionary containing parameter
specifications.
Returns:
str: The deduced parameter type ('int', 'float', 'categorical', or 'constant').
Expand All @@ -112,12 +142,19 @@ def deduce_param_type(name, details):
elif isinstance(details["lower"], float) and isinstance(details["upper"], float):
param_type = "float"
else:
details["lower"], flag_lower = convert_scientific_notation(
details["lower"], show_usage_flag=True
)
details["upper"], flag_upper = convert_scientific_notation(
details["upper"], show_usage_flag=True
)
try:
details["lower"], flag_lower = convert_scientific_notation(
details["lower"], show_usage_flag=True
)
details["upper"], flag_upper = convert_scientific_notation(
details["upper"], show_usage_flag=True
)
except ValueError as e:
raise TypeError(
f"Inconsistent types for 'lower' and 'upper' in '{name}'. "
f"Both must be either integers or floats."
) from e

# check if one value is e notation and if so convert it to float
if flag_lower or flag_upper:
param_type = "float"
Expand Down Expand Up @@ -145,7 +182,9 @@ def deduce_param_type(name, details):
return param_type


def validate_param_details(name, param_type, details):
def validate_param_details(
name: str, param_type: str, details: dict[str, int | str | float]
):
"""
Validates the details of a parameter based on its type.
Expand All @@ -166,9 +205,6 @@ def validate_param_details(name, param_type, details):
the necessary keys in the 'details' are missing based on the parameter type.
TypeError: If the 'param_type' is not one of the supported types.
Returns:
str: The parameter type in lowercase.
Example Usage:
validate_param_details("learning_rate", "float", {"lower": 0.01, "upper": 0.1,
"default": 0.05})
Expand Down Expand Up @@ -203,11 +239,28 @@ def validate_param_details(name, param_type, details):
"For categorical parameter: cat, categorical\n"
"For constant parameter: const, constant\n"
)
return param_type


def validate_integer_parameter(name, details):
"""validate int parameter and convert e notation values to int"""
def validate_integer_parameter(name: str, details: dict[str, str | int | float]):
"""
Validates and processes an integer parameter's details, converting scientific
notation to integers where necessary.
This function checks the type of 'lower' and 'upper', and the 'default'
value (if present) for an integer parameter. It also handles conversion of values
in scientific notation (e.g., 1e2) to integers.
Args:
name (str): The name of the integer parameter.
details (dict[str, str | int | float]): A dictionary containing the parameter's
specifications. Expected keys include
'lower', 'upper', and optionally 'default',
among others.
Raises:
TypeError: If 'lower', 'upper', or 'default' are not valid integers or cannot
be converted from scientific notation to integers.
"""
# check if all keys are allowed to use and if the mandatory ones are provided
check_keys(
name,
Expand Down Expand Up @@ -254,8 +307,24 @@ def validate_integer_parameter(name, details):
) from e


def validate_float_parameter(name, details):
"""validate float parameter and convert e notation values to float"""
def validate_float_parameter(name: str, details: dict[str, str | int | float]):
"""
Validates and processes a float parameter's details, converting scientific
notation values to float where necessary.
This function checks the type of 'lower' and 'upper', and the 'default'
value (if present) for a float parameter. It handles conversion of values in
scientific notation (e.g., 1e-5) to float.
Args:
name: The name of the float parameter.
details: A dictionary containing the parameter's specifications. Expected keys
include 'lower', 'upper', and optionally 'default', among others.
Raises:
TypeError: If 'lower', 'upper', or 'default' are not valid floats or cannot
be converted from scientific notation to floats.
"""
# check if all keys are allowed to use and if the mandatory ones are provided
check_keys(
name,
Expand Down Expand Up @@ -284,8 +353,23 @@ def validate_float_parameter(name, details):
) from e


def validate_categorical_parameter(name, details):
"""validate categorical parameter and convert e notation values to float"""
def validate_categorical_parameter(name: str, details: dict[str, str | int | float]):
"""
Validates a categorical parameter, including conversion of scientific notation
values to floats within the choices.
This function ensures that the 'choices' key in the details is a list and attempts
to convert any elements in scientific notation to floats. It also handles the
'default' value, converting it from scientific notation if necessary.
Args:
name: The name of the categorical parameter.
details: A dictionary containing the parameter's specifications. Required key
is 'choices', with 'default' being optional.
Raises:
TypeError: If 'choices' is not a list.
"""
# check if all keys are allowed to use and if the mandatory ones are provided
check_keys(
name,
Expand Down Expand Up @@ -320,8 +404,20 @@ def validate_categorical_parameter(name, details):
details["default"] = default


def validate_constant_parameter(name, details):
"""Validate constant parameter and convert e notation to float"""
def validate_constant_parameter(name: str, details: dict[str, str | int | float]):
"""
Validates a constant parameter, including conversion of values in scientific
notation to floats.
This function checks the 'value' key in the details dictionary and converts any
value expressed in scientific notation to a float. It ensures that the mandatory
'value' key is provided and appropriately formatted.
Args:
name: The name of the constant parameter.
details: A dictionary containing the parameter's specifications. The required
key is 'value'.
"""
# check if all keys are allowed to use and if the mandatory ones are provided
check_keys(name, details, {"value", "type", "is_fidelity"}, {"value"})

Expand All @@ -339,22 +435,27 @@ def validate_constant_parameter(name, details):
details["value"] = converted_value


def check_keys(name, my_dict, allowed_keys, mandatory_keys):
def check_keys(
name: str,
details: dict[str, str | int | float],
allowed_keys: set,
mandatory_keys: set,
):
"""
Checks if all keys in 'details' are contained in the set 'allowed_keys' and
if all keys in 'mandatory_keys' are present in 'details'.
Raises an exception if an unallowed key is found or if a mandatory key is missing.
"""
# Check for unallowed keys
unallowed_keys = [key for key in my_dict if key not in allowed_keys]
unallowed_keys = [key for key in details if key not in allowed_keys]
if unallowed_keys:
unallowed_keys_str = ", ".join(unallowed_keys)
raise KeyError(
f"Unallowed key(s) '{unallowed_keys_str}' found for parameter '" f"{name}'."
)

# Check for missing mandatory keys
missing_mandatory_keys = [key for key in mandatory_keys if key not in my_dict]
missing_mandatory_keys = [key for key in mandatory_keys if key not in details]
if missing_mandatory_keys:
missing_keys_str = ", ".join(missing_mandatory_keys)
raise KeyError(
Expand Down

0 comments on commit 91d8a45

Please sign in to comment.