Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing tweaks #696

Merged
merged 12 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlos_bench/mlos_bench/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def target(self) -> str:
return self._opt_target

@property
def direction(self) -> str:
def direction(self) -> Literal['min', 'max']:
"""
The direction to optimize the target metric (e.g., min or max).
"""
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/storage/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pandas

from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable import TunableValue, TunableValueTypeTuple
from mlos_bench.util import try_parse_val


Expand All @@ -32,7 +32,7 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu
data = {}
for _, row in dataframe.astype('O').iterrows():
assert isinstance(row['parameter'], str)
assert row['value'] is None or isinstance(row['value'], (str, int, float))
assert isinstance(row['value'], TunableValueTypeTuple)
bpkroth marked this conversation as resolved.
Show resolved Hide resolved
if row['parameter'] in data:
raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe")
data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value']
Expand Down
36 changes: 18 additions & 18 deletions mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json5 as json
import pytest

from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName


def test_tunable_name() -> None:
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_categorical_tunable_disallow_repeats() -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_null_default(tunable_type: str) -> None:
def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None:
"""
Disallow null values as default for numerical tunables.
"""
Expand All @@ -148,7 +148,7 @@ def test_numerical_tunable_disallow_null_default(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_disallow_out_of_range(tunable_type: str) -> None:
def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow out of range values as default for numerical tunables.
"""
Expand All @@ -161,7 +161,7 @@ def test_numerical_tunable_disallow_out_of_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_wrong_params(tunable_type: str) -> None:
def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> None:
"""
Disallow values param for numerical tunables.
"""
Expand All @@ -175,7 +175,7 @@ def test_numerical_tunable_wrong_params(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_required_params(tunable_type: str) -> None:
def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) -> None:
"""
Disallow null values param for numerical tunables.
"""
Expand All @@ -192,7 +192,7 @@ def test_numerical_tunable_required_params(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_invalid_range(tunable_type: str) -> None:
def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow invalid range param for numerical tunables.
"""
Expand All @@ -209,7 +209,7 @@ def test_numerical_tunable_invalid_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_tunable_reversed_range(tunable_type: str) -> None:
def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> None:
"""
Disallow reverse range param for numerical tunables.
"""
Expand All @@ -226,7 +226,7 @@ def test_numerical_tunable_reversed_range(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights(tunable_type: str) -> None:
def test_numerical_weights(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with weighted special values.
"""
Expand All @@ -248,7 +248,7 @@ def test_numerical_weights(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization(tunable_type: str) -> None:
def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with quantization.
"""
Expand All @@ -267,7 +267,7 @@ def test_numerical_quantization(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_log(tunable_type: str) -> None:
def test_numerical_log(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with log scale.
"""
Expand All @@ -285,7 +285,7 @@ def test_numerical_log(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_specials(tunable_type: str) -> None:
def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> None:
"""
Raise an error if special_weights are specified but no special values.
"""
Expand All @@ -303,7 +303,7 @@ def test_numerical_weights_no_specials(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_non_normalized(tunable_type: str) -> None:
def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with non-normalized weights
of the special values.
Expand All @@ -326,7 +326,7 @@ def test_numerical_weights_non_normalized(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_count(tunable_type: str) -> None:
def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with incorrect number of weights.
"""
Expand All @@ -346,7 +346,7 @@ def test_numerical_weights_wrong_count(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_no_range_weight(tunable_type: str) -> None:
def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with weights but no range_weight.
"""
Expand All @@ -365,7 +365,7 @@ def test_numerical_weights_no_range_weight(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_weights(tunable_type: str) -> None:
def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with specials but no range_weight.
"""
Expand All @@ -384,7 +384,7 @@ def test_numerical_range_weight_no_weights(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_range_weight_no_specials(tunable_type: str) -> None:
def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with specials but no range_weight.
"""
Expand All @@ -402,7 +402,7 @@ def test_numerical_range_weight_no_specials(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_weights_wrong_values(tunable_type: str) -> None:
def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> None:
"""
Try to instantiate a numerical tunable with incorrect number of weights.
"""
Expand All @@ -422,7 +422,7 @@ def test_numerical_weights_wrong_values(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_quantization_wrong(tunable_type: str) -> None:
def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> None:
"""
Instantiate a numerical tunable with invalid number of quantization points.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json5 as json
import pytest

from mlos_bench.tunables.tunable import Tunable
from mlos_bench.tunables.tunable import Tunable, TunableValueTypeName


def test_categorical_distribution() -> None:
Expand All @@ -28,7 +28,7 @@ def test_categorical_distribution() -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_uniform(tunable_type: str) -> None:
def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit uniform distribution.
"""
Expand All @@ -46,7 +46,7 @@ def test_numerical_distribution_uniform(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_normal(tunable_type: str) -> None:
def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit Gaussian distribution specified.
"""
Expand All @@ -67,7 +67,7 @@ def test_numerical_distribution_normal(tunable_type: str) -> None:


@pytest.mark.parametrize("tunable_type", ["int", "float"])
def test_numerical_distribution_beta(tunable_type: str) -> None:
def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None:
"""
Create a numeric Tunable with explicit Beta distribution specified.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,19 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None:
"""
Checks that we can't use null/None in integer tunables.
"""
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_int.value = None
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_int.numerical_value = None # type: ignore[assignment]


def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None:
"""
Checks that we can't use null/None in float tunables.
"""
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_float.value = None
with pytest.raises(TypeError):
with pytest.raises((TypeError, AssertionError)):
tunable_float.numerical_value = None # type: ignore[assignment]


Expand Down
27 changes: 22 additions & 5 deletions mlos_bench/mlos_bench/tunables/tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,22 @@
"""A tunable parameter value type alias."""
TunableValue = Union[int, float, Optional[str]]

"""Tunable value type."""
TunableValueType = Union[Type[int], Type[float], Type[str]]

"""
Tunable value type tuple.
For checking with isinstance()
"""
TunableValueTypeTuple = (int, float, str, type(None))

"""The string name of a tunable value type."""
TunableValueTypeName = Literal["int", "float", "categorical"]

"""Tunable values dictionary type"""
TunableValuesDict = Dict[str, TunableValue]

"""Tunable value distribution type"""
DistributionName = Literal["uniform", "normal", "beta"]


Expand All @@ -38,7 +54,7 @@ class TunableDict(TypedDict, total=False):
These are the types expected to be received from the json config.
"""

type: str
type: TunableValueTypeName
description: Optional[str]
default: TunableValue
values: Optional[List[Optional[str]]]
Expand All @@ -59,7 +75,7 @@ class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-m
"""

# Maps tunable types to their corresponding Python types by name.
_DTYPE: Dict[str, Type] = {
_DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
"int": int,
"float": float,
"categorical": str,
Expand All @@ -79,7 +95,7 @@ def __init__(self, name: str, config: TunableDict):
if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema
raise ValueError(f"Invalid name of the tunable: {name}")
self._name = name
self._type = config["type"] # required
self._type: TunableValueTypeName = config["type"] # required
if self._type not in self._DTYPE:
raise ValueError(f"Invalid parameter type: {self._type}")
self._description = config.get("description")
Expand Down Expand Up @@ -302,6 +318,7 @@ def value(self, value: TunableValue) -> TunableValue:
if self.is_categorical and value is None:
coerced_value = None
else:
assert value is not None
coerced_value = self.dtype(value)
except Exception:
_LOG.error("Impossible conversion: %s %s <- %s %s",
Expand Down Expand Up @@ -482,7 +499,7 @@ def range_weight(self) -> Optional[float]:
return self._range_weight

@property
def type(self) -> str:
def type(self) -> TunableValueTypeName:
"""
Get the data type of the tunable.

Expand All @@ -494,7 +511,7 @@ def type(self) -> str:
return self._type

@property
def dtype(self) -> Type:
def dtype(self) -> TunableValueType:
"""
Get the actual Python data type of the tunable.

Expand Down
Loading