Skip to content

Commit

Permalink
Add explicit tests for quantized values with llamatune
Browse files Browse the repository at this point in the history
  • Loading branch information
bpkroth committed Jul 23, 2024
1 parent 60bd4f0 commit ef06e8c
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@
from mlos_core.spaces.adapters import LlamaTuneAdapter


def construct_parameter_space(
# Explicitly test quantized values with llamatune space adapter.
# TODO: Add log scale sampling tests as well.


def construct_parameter_space( # pylint: disable=too-many-arguments
*,
n_continuous_params: int = 0,
n_quantized_continuous_params: int = 0,
n_integer_params: int = 0,
n_quantized_integer_params: int = 0,
n_categorical_params: int = 0,
seed: int = 1234,
) -> CS.ConfigurationSpace:
Expand All @@ -28,10 +35,18 @@ def construct_parameter_space(
input_space.add_hyperparameter(
CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64)
)
for idx in range(n_quantized_continuous_params):
input_space.add_hyperparameter(
CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64, q=12.8)
)
for idx in range(n_integer_params):
input_space.add_hyperparameter(
CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256)
)
for idx in range(n_quantized_integer_params):
input_space.add_hyperparameter(
CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=0, upper=256, q=16)
)
for idx in range(n_categorical_params):
input_space.add_hyperparameter(
CS.CategoricalHyperparameter(
Expand All @@ -53,6 +68,13 @@ def construct_parameter_space(
{"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_integer_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_quantized_integer_params": int(num_target_space_dims * num_orig_space_factor)},
{
"n_quantized_continuous_params": int(
num_target_space_dims * num_orig_space_factor
)
},
# Mix of all three types
{
"n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3),
Expand Down Expand Up @@ -374,6 +396,12 @@ def test_max_unique_values_per_param() -> None:
{"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_integer_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)},
{"n_quantized_integer_params": int(num_target_space_dims * num_orig_space_factor)},
{
"n_quantized_continuous_params": int(
num_target_space_dims * num_orig_space_factor
)
},
# Mix of all three types
{
"n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3),
Expand Down

0 comments on commit ef06e8c

Please sign in to comment.